<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          【深度學(xué)習(xí)】神經(jīng)網(wǎng)絡(luò)模型特征重要性可以查看了?。?!

          共 1443字,需瀏覽 3分鐘

           ·

          2021-10-13 07:42

          作者:杰少

          查看NN模型特征重要性的技巧

          簡(jiǎn) 介

          我們都知道樹(shù)模型的特征重要性是非常容易繪制出來(lái)的,只需要直接調(diào)用樹(shù)模型自帶的API即可以得到在樹(shù)模型中每個(gè)特征的重要性,那么對(duì)于神經(jīng)網(wǎng)絡(luò)我們?cè)撊绾蔚玫狡涮卣髦匾阅兀?/span>

          本篇文章我們就以LSTM為例,來(lái)介紹神經(jīng)網(wǎng)絡(luò)中模型特征重要性的一種獲取方式。

          NN模型特征重要性

          01


          基本思路

          該策略的思想來(lái)源于:Permutation Feature Importance,我們以特征對(duì)于模型最終預(yù)測(cè)結(jié)果的變化來(lái)衡量特征的重要性。

          02


          實(shí)現(xiàn)步驟


          NN模型特征重要性的獲取步驟如下:

          1. 訓(xùn)練一個(gè)NN;
          2. 每次獲取一個(gè)特征列,然后對(duì)其進(jìn)行隨機(jī)shuffle,使用模型對(duì)其進(jìn)行預(yù)測(cè)并得到Loss;
          3. 記錄每個(gè)特征列以及其對(duì)應(yīng)的Loss;
          4. 每個(gè)Loss就是該特征對(duì)應(yīng)的特征重要性,如果Loss越大,說(shuō)明該特征對(duì)于NN模型越加重要;反之,則越加不重要。
          Code

          代碼摘自:https://www.kaggle.com/cdeotte/lstm-feature-importance/notebook

          import?matplotlib.pyplot?as?plt
          from?tqdm.notebook?import?tqdm

          import?tensorflow?as?tf
          from?tensorflow?import?keras
          import?tensorflow.keras.backend?as?K
          from?tensorflow.keras.callbacks?import?EarlyStopping,?ModelCheckpoint
          from?tensorflow.keras.callbacks?import?LearningRateScheduler,?ReduceLROnPlateau
          from?tensorflow.keras.optimizers.schedules?import?ExponentialDecay
          from?sklearn.metrics?import?mean_absolute_error?as?mae
          from?sklearn.preprocessing?import?RobustScaler,?normalize
          from?sklearn.model_selection?import?train_test_split,?GroupKFold,?KFold
          from?IPython.display?import?display

          COMPUTE_LSTM_IMPORTANCE?=?1
          ONE_FOLD_ONLY?=?1

          with?gpu_strategy.scope():
          ????kf?=?KFold(n_splits=NUM_FOLDS,?shuffle=True,?random_state=2021)
          ????test_preds?=?[]
          ????for?fold,?(train_idx,?test_idx)?in?enumerate(kf.split(train,?targets)):
          ????????K.clear_session()
          ????????
          ????????print('-'*15,?'>',?f'Fold?{fold+1}',?'<',?'-'*15)
          ????????X_train,?X_valid?=?train[train_idx],?train[test_idx]
          ????????y_train,?y_valid?=?targets[train_idx],?targets[test_idx]
          ????????
          ????????#?導(dǎo)入已經(jīng)訓(xùn)練好的模型
          ????????model?=?keras.models.load_model('models/XXX.h5')
          ????????#?計(jì)算特征重要性
          ????????if?COMPUTE_LSTM_IMPORTANCE:
          ????????????results?=?[]
          ????????????print('?Computing?LSTM?feature?importance...')

          ????????????for?k?in?tqdm(range(len(COLS))):
          ????????????????if?k>0:?
          ????????????????????save_col?=?X_valid[:,:,k-1].copy()
          ????????????????????np.random.shuffle(X_valid[:,:,k-1])
          ????????????????????????
          ????????????????oof_preds?=?model.predict(X_valid,?verbose=0).squeeze()?
          ????????????????mae?=?np.mean(np.abs(?oof_preds-y_valid?))
          ????????????????results.append({'feature':COLS[k],'mae':mae})
          ????????
          ????????????????if?k>0:?
          ????????????????????X_valid[:,:,k-1]?=?save_col
          ?????????
          ????????????#?展示特征重要性
          ????????????print()
          ????????????df?=?pd.DataFrame(results)
          ????????????df?=?df.sort_values('mae')
          ????????????plt.figure(figsize=(10,20))
          ????????????plt.barh(np.arange(len(COLS)),df.mae)
          ????????????plt.yticks(np.arange(len(COLS)),df.feature.values)
          ????????????plt.title('LSTM?Feature?Importance',size=16)
          ????????????plt.ylim((-1,len(COLS)))
          ????????????plt.show()
          ???????????????????????????????
          ????????????#?SAVE?LSTM?FEATURE?IMPORTANCE
          ????????????df?=?df.sort_values('mae',ascending=False)
          ????????????df.to_csv(f'lstm_feature_importance_fold_{fold}.csv',index=False)
          ???????????????????????????????
          ????????#?ONLY?DO?ONE?FOLD
          ????????if?ONE_FOLD_ONLY:?break


          適用情況
          適用于所有的NN模型。
          參考文獻(xiàn)
          1. https://www.kaggle.com/cdeotte/lstm-feature-importance/notebook
          2. Permutation Feature Importance
          往期精彩回顧




          本站qq群851320808,加入微信群請(qǐng)掃碼:
          瀏覽 198
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  一区二区视频免费 | 天天日天天草 | 欧美精品久久久久久久久 | 三级成人AV在线电影 | 噜噜射亚洲 |