【機器學習】四種超參數(shù)搜索方法
在建模時模型的超參數(shù)對精度有一定的影響,而設置和調(diào)整超參數(shù)的取值,往往稱為調(diào)參。
在實踐中調(diào)參往往依賴人工來進行設置調(diào)整范圍,然后使用機器在超參數(shù)范圍內(nèi)進行搜素。本文將演示在sklearn中支持的四種基礎超參數(shù)搜索方法:
-
GridSearch -
RandomizedSearch -
HalvingGridSearch -
HalvingRandomSearch
原始模型
作為精度對比,我們最開始使用隨機森林來訓練初始化模型,并在測試集計算精度:
#?數(shù)據(jù)讀取
df?=?pd.read_csv('https://mirror.coggle.club/dataset/heart.csv')
X?=?df.drop(columns=['output'])
y?=?df['output']
#?數(shù)據(jù)劃分
x_train,?x_test,?y_train,?y_test?=?train_test_split(X,?y,?stratify=y)
#?模型訓練與計算準確率
clf?=?RandomForestClassifier(random_state=0)
clf.fit(x_train,?y_train)
clf.score(x_test,?y_test)
模型最終在測試集精度為:0.802。
GridSearch
GridSearch是比較基礎的超參數(shù)搜索方法,中文名字網(wǎng)格搜索。其原理是在計算的過程中遍歷所有的超參數(shù)組合,然后搜索到最優(yōu)的結果。
如下代碼所示,我們對4個超參數(shù)進行搜索,搜索空間為 5 * 3 * 2 * 3 = 90組超參數(shù)。對于每組超參數(shù)還需要計算5折交叉驗證,則需要訓練450次。
parameters?=?{
????'max_depth':?[2,4,5,6,7],
????'min_samples_leaf':?[1,2,3],
????'min_weight_fraction_leaf':?[0,?0.1],
????'min_impurity_decrease':?[0,?0.1,?0.2]
}
#?Fitting?5?folds?for?each?of?90?candidates,?totalling?450?fits
clf?=?GridSearchCV(
????RandomForestClassifier(random_state=0),
????parameters,?refit=True,?verbose=1,
)
clf.fit(x_train,?y_train)
clf.best_estimator_.score(x_test,?y_test)
模型最終在測試集精度為:0.815。
RandomizedSearch
RandomizedSearch是在一定范圍內(nèi)進行搜索,且需要設置搜索的次數(shù),其默認不會對所有的組合進行搜索。
n_iter代表超參數(shù)組合的個數(shù),默認會設置比所有組合次數(shù)少的取值,如下面設置的為10,則只進行50次訓練。
parameters?=?{
????'max_depth':?[2,4,5,6,7],
????'min_samples_leaf':?[1,2,3],
????'min_weight_fraction_leaf':?[0,?0.1],
????'min_impurity_decrease':?[0,?0.1,?0.2]
}
clf?=?RandomizedSearchCV(
????RandomForestClassifier(random_state=0),
????parameters,?refit=True,?verbose=1,?n_iter=10,
)
clf.fit(x_train,?y_train)
clf.best_estimator_.score(x_test,?y_test)
模型最終在測試集精度為:0.815。
HalvingGridSearch
HalvingGridSearch和GridSearch非常相似,但在迭代的過程中是有參數(shù)組合減半的操作。
最開始使用所有的超參數(shù)組合,但使用最少的數(shù)據(jù),篩選其中最優(yōu)的超參數(shù),增加數(shù)據(jù)再進行篩選。
HalvingGridSearch的思路和hyperband的思路非常相似,但是最樸素的實現(xiàn)。先使用少量數(shù)據(jù)篩選超參數(shù)組合,然后使用更多的數(shù)據(jù)驗證精度。
n_iterations:?3
n_required_iterations:?5
n_possible_iterations:?3
min_resources_:?20
max_resources_:?227
aggressive_elimination:?False
factor:?3
----------
iter:?0
n_candidates:?90
n_resources:?20
Fitting?5?folds?for?each?of?90?candidates,?totalling?450?fits
----------
iter:?1
n_candidates:?30
n_resources:?60
Fitting?5?folds?for?each?of?30?candidates,?totalling?150?fits
----------
iter:?2
n_candidates:?10
n_resources:?180
Fitting?5?folds?for?each?of?10?candidates,?totalling?50?fits
----------
模型最終在測試集精度為:0.855。
HalvingRandomSearch
HalvingRandomSearch和HalvingGridSearch類似,都是逐步增加樣本,減少超參數(shù)組合。但每次生成超參數(shù)組合,都是隨機篩選的。
n_iterations:?3
n_required_iterations:?3
n_possible_iterations:?3
min_resources_:?20
max_resources_:?227
aggressive_elimination:?False
factor:?3
----------
iter:?0
n_candidates:?11
n_resources:?20
Fitting?5?folds?for?each?of?11?candidates,?totalling?55?fits
----------
iter:?1
n_candidates:?4
n_resources:?60
Fitting?5?folds?for?each?of?4?candidates,?totalling?20?fits
----------
iter:?2
n_candidates:?2
n_resources:?180
Fitting?5?folds?for?each?of?2?candidates,?totalling?10?fits
模型最終在測試集精度為:0.828。
總結與對比
HalvingGridSearch和HalvingRandomSearch比較適合在數(shù)據(jù)量比較大的情況使用,可以提高訓練速度。如果計算資源充足,GridSearch和HalvingGridSearch會得到更好的結果。
后續(xù)我們將分享其他的一些高階調(diào)參庫的實現(xiàn),其中也會有數(shù)據(jù)量改變的思路。如在Optuna中,核心是參數(shù)組合的生成和剪枝、訓練的樣本增加等細節(jié)。
往期
精彩
回顧
- 適合初學者入門人工智能的路線及資料下載
- (圖文+視頻)機器學習入門系列下載
- 機器學習及深度學習筆記等資料打印
- 《統(tǒng)計學習方法》的代碼復現(xiàn)專輯
- 機器學習交流qq群955171419,加入微信群請 掃碼
