【Deep Learning】You can now inspect feature importance for neural network models!!!
A trick for inspecting feature importance in NN models

We all know that feature importance is easy to plot for tree models: just call the tree model's built-in API to get the importance of every feature within the model. So how do we obtain feature importance for a neural network?
In this article we take an LSTM as an example and introduce one way of obtaining feature importance for neural network models.

Basic idea
The strategy comes from Permutation Feature Importance: we measure a feature's importance by how much the model's final prediction error changes when that feature is perturbed. The intuition is that if randomly shuffling a feature column noticeably degrades the predictions, the model must have been relying on that feature.
Implementation steps
The steps for obtaining NN feature importance are as follows (a minimal sketch follows this list):
1. Train an NN.
2. Take one feature column at a time, randomly shuffle it, run the model on the perturbed data, and compute the Loss.
3. Record each feature column together with its Loss.
4. Each Loss is that feature's importance: the larger the Loss, the more important the feature is to the NN model; conversely, the smaller the Loss, the less important it is.
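To make the procedure concrete before the LSTM-specific code, here is a minimal, framework-agnostic sketch of permutation importance. The names model, X_valid (a 2-D array of shape [n_samples, n_features]), y_valid and feature_names are hypothetical placeholders for this sketch, not from the notebook below:

import numpy as np

def permutation_importance(model, X_valid, y_valid, feature_names):
    # Baseline error with the validation data untouched.
    base_mae = np.mean(np.abs(model.predict(X_valid).squeeze() - y_valid))
    results = [{'feature': 'baseline', 'mae': base_mae}]
    for k, name in enumerate(feature_names):
        save_col = X_valid[:, k].copy()    # back up the column
        np.random.shuffle(X_valid[:, k])   # permute it in place
        preds = model.predict(X_valid).squeeze()
        results.append({'feature': name, 'mae': np.mean(np.abs(preds - y_valid))})
        X_valid[:, k] = save_col           # restore before the next feature
    return results

Because each column is restored right after it is scored, the validation set is left unchanged and the model never needs to be retrained; this is what makes the method cheap compared with refitting the NN once per feature.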

The code below is adapted from: https://www.kaggle.com/cdeotte/lstm-feature-importance/notebook
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.callbacks import LearningRateScheduler, ReduceLROnPlateau
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import RobustScaler, normalize
from sklearn.model_selection import train_test_split, GroupKFold, KFold
from IPython.display import display

COMPUTE_LSTM_IMPORTANCE = 1
ONE_FOLD_ONLY = 1
# gpu_strategy, NUM_FOLDS, train, targets and COLS are defined earlier in the
# original notebook; COLS is assumed (as in that notebook) to put a 'baseline'
# placeholder at index 0 ahead of the feature names, so the k == 0 pass below
# scores the unshuffled data.
with gpu_strategy.scope():
    kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=2021)
    test_preds = []
    for fold, (train_idx, test_idx) in enumerate(kf.split(train, targets)):
        K.clear_session()

        print('-'*15, '>', f'Fold {fold+1}', '<', '-'*15)
        X_train, X_valid = train[train_idx], train[test_idx]
        y_train, y_valid = targets[train_idx], targets[test_idx]

        # Load the already-trained model
        model = keras.models.load_model('models/XXX.h5')
        # Compute feature importance
        if COMPUTE_LSTM_IMPORTANCE:
            results = []
            print(' Computing LSTM feature importance...')
            for k in tqdm(range(len(COLS))):
                if k > 0:
                    # Back up feature column k-1, then shuffle it in place
                    # (the k == 0 pass is the unshuffled baseline).
                    save_col = X_valid[:,:,k-1].copy()
                    np.random.shuffle(X_valid[:,:,k-1])

                oof_preds = model.predict(X_valid, verbose=0).squeeze()
                mae = np.mean(np.abs(oof_preds - y_valid))
                results.append({'feature':COLS[k], 'mae':mae})

                if k > 0:
                    # Restore the original column before the next feature
                    X_valid[:,:,k-1] = save_col

            # Plot the feature importance
            print()
            df = pd.DataFrame(results)
            df = df.sort_values('mae')
            plt.figure(figsize=(10,20))
            plt.barh(np.arange(len(COLS)), df.mae)
            plt.yticks(np.arange(len(COLS)), df.feature.values)
            plt.title('LSTM Feature Importance', size=16)
            plt.ylim((-1, len(COLS)))
            plt.show()

            # Save the LSTM feature importance
            df = df.sort_values('mae', ascending=False)
            df.to_csv(f'lstm_feature_importance_fold_{fold}.csv', index=False)

        # Only run one fold
        if ONE_FOLD_ONLY:
            break
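The raw MAE values mix the unshuffled baseline pass (k == 0) with the shuffled passes, so a convenient way to read them is as the increase over the baseline: the features whose shuffling raises the MAE the most matter the most. Below is a small post-processing sketch under the same assumption as above, namely that COLS[0] was a 'baseline' placeholder so the no-shuffle pass is stored under the feature name 'baseline':

import pandas as pd

# Load the per-fold importance file written by the loop above.
df = pd.read_csv('lstm_feature_importance_fold_0.csv')

# MAE of the unshuffled baseline pass.
baseline_mae = df.loc[df['feature'] == 'baseline', 'mae'].iloc[0]

# Importance = how much the validation MAE grows when the feature is shuffled.
df['importance'] = df['mae'] - baseline_mae
df = df[df['feature'] != 'baseline'].sort_values('importance', ascending=False)
print(df.head(10))

A feature with an importance near zero (or below it) can be shuffled without hurting the model, which is also a quick way to spot candidates for removal.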


References
1. https://www.kaggle.com/cdeotte/lstm-feature-importance/notebook
2. Permutation Feature Importance
