神經(jīng)網(wǎng)絡(luò)模型特征重要性-谷歌解決方案
背景
樹模型的特征重要性是相對(duì)容易計(jì)算出來的,那么對(duì)于神經(jīng)網(wǎng)絡(luò)我們?cè)撊绾蔚玫狡涮卣髦匾阅兀烤W(wǎng)上的方法包括permutation importance, null importance, 隨機(jī)對(duì)特征進(jìn)行mask等方法,本文要介紹的是牛津大學(xué)和谷歌提出的基于Gated Residual Networks (GRN) and Variable Selection Networks (VSN)的特征重要性計(jì)算方法。
使用GRN計(jì)算特征重要性的基本邏輯
1
提供特征列和target列,特征根據(jù)數(shù)據(jù)類型指定為數(shù)值型或離散型;
2
將數(shù)據(jù)劃分為驗(yàn)證集和訓(xùn)練集;
3
在訓(xùn)練集上,根據(jù)第一步提供的列定義,對(duì)數(shù)值型特征分別進(jìn)行歸一化,對(duì)離散型特征進(jìn)行embedding(相當(dāng)于給每一列離散特征創(chuàng)建一個(gè)類別詞典),然后將兩者拼接后傳給GRN模塊,計(jì)算得到每個(gè)特征的權(quán)重,再將權(quán)重和前面拼接后的結(jié)果按元素相乘,最后接全連接層,獲得預(yù)測結(jié)果,與真實(shí)值計(jì)算loss,迭代訓(xùn)練。
其中,GRN模塊獲取拼接輸入后,分成兩路,其中一路經(jīng)過多層變換最后額外使用sigmoid激活函數(shù)作為“門”對(duì)變換結(jié)果進(jìn)行選擇性加權(quán),另一路則作為residual connection直接與前者的輸出相加,起到防止過擬合的作用。
4
取驗(yàn)證集得分最優(yōu)的情況下各個(gè)特征的權(quán)重為特征的重要性。

使用案例
我們的開源項(xiàng)目AutoX把GRN計(jì)算特征重要性以及特征選擇的函數(shù)進(jìn)行了封裝:
使用GRN_feature_selection進(jìn)行特征重要性計(jì)算:
from autox.autox_competition.feature_selection import GRN_feature_selectionGRN_feature_selection = GRN_feature_selection()column_definition = {"cat":['investment_id'],"num":[]}for i in range(300):column_definition['num'].append((f'f_{i}'))GRN_feature_selection.fit(train[used], train[target], column_definition)# Train[used]是完成了包含所有特征的dataframe# Train[target]是標(biāo)簽列# column_definition中指定了特征的類型(特征屬于類別型變量還是連續(xù)型變量)
查看所有特征的重要性:
GRN_feature_selection.feature2weight
選擇top_k重要性的特征:
train_select = GRN_feature_selection.transform(train[used], top_k=20)test_select = GRN_feature_selection.transform(test[used], top_k=20)

完整案例地址
https://www.kaggle.com/code/hengwdai/grn-featureselection-autox
開源項(xiàng)目地址
https://github.com/4paradigm/AutoX
參考資料
Lim B, Arik S O, Loeff N, et al. Temporal fusion transformers for interpretable multi-horizon time series forecasting[J]. arXiv preprint arXiv:1912.09363, 2019.
往期精彩回顧
適合初學(xué)者入門人工智能的路線及資料下載 (圖文+視頻)機(jī)器學(xué)習(xí)入門系列下載 中國大學(xué)慕課《機(jī)器學(xué)習(xí)》(黃海廣主講) 機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印 《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專輯 AI基礎(chǔ)下載 機(jī)器學(xué)習(xí)交流qq群955171419,加入微信群請(qǐng)掃碼:
