【神經(jīng)網(wǎng)絡(luò)搜索】1. NAS-RL(ICLR2017)
谷歌最早發(fā)表的有關(guān)NAS的文章,全稱Neural Architecture Search with Reinforcement Learning
文章鏈接:https://arxiv.org/pdf/1611.01578.pdf

神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索經(jīng)典范式是,首先通過(guò)controller以p概率采樣一個(gè)網(wǎng)絡(luò)結(jié)構(gòu),然后開(kāi)始訓(xùn)練網(wǎng)絡(luò)結(jié)構(gòu)得到準(zhǔn)確率R,根據(jù)準(zhǔn)確率R和概率p可以使用梯度上升的方法更新controller的參數(shù)。
在NAS-RL中,使用了Policy Gradient算法來(lái)訓(xùn)練controller(通常實(shí)現(xiàn)是一個(gè)RNN或者LSTM)。訓(xùn)練完采樣網(wǎng)絡(luò)后在驗(yàn)證集上得到的準(zhǔn)確率就是環(huán)境反饋的獎(jiǎng)勵(lì)值Reward,根據(jù)這個(gè)Reward可以通過(guò)梯度優(yōu)化的方法得到最優(yōu)的RNN和網(wǎng)絡(luò)結(jié)構(gòu)。
1.1 網(wǎng)絡(luò)結(jié)構(gòu)的表示
在神經(jīng)網(wǎng)絡(luò)搜索中,controller生成了一系列代表結(jié)構(gòu)的超參數(shù)(tokens)。

上圖展示了一個(gè)RNN生成超參數(shù)的詳細(xì)過(guò)程,每五個(gè)輸出結(jié)果組成一個(gè)Layer,每個(gè)Layer中
包含了一個(gè)卷積所需要的參數(shù),主要包含:
卷積核高 卷積核寬 Stride高 Stride寬 濾波器個(gè)數(shù)
如果想要加上類似ResNet的skip connection結(jié)構(gòu),可以引入Anchor Point進(jìn)行指向:

1.2 用REINFORCE進(jìn)行訓(xùn)練
Controller預(yù)測(cè)的一系列tokens可以被視為一系列Action , 根據(jù)token可以得到對(duì)應(yīng)的網(wǎng)絡(luò)結(jié)構(gòu),在訓(xùn)練集上訓(xùn)練生成的結(jié)構(gòu),在驗(yàn)證集上得到準(zhǔn)確率 R, 相當(dāng)于得到了獎(jiǎng)勵(lì)Reward,根據(jù)獎(jiǎng)勵(lì)值可以使用強(qiáng)化學(xué)習(xí)方法訓(xùn)練Controller。
將目標(biāo)函數(shù)如下:
目標(biāo)函數(shù)的意義是,在當(dāng)前一系列Action并且Controller的參數(shù)為的情況下,希望得到的獎(jiǎng)勵(lì)R的期望盡可能大。而R是不可微的,所以只能采用迭代的方式逼近最優(yōu)結(jié)果。
R的表達(dá)式:, 其中代表一系列Action。
對(duì)進(jìn)行求偏導(dǎo)得到以下結(jié)果:
在本問(wèn)題中,如果對(duì)進(jìn)行求導(dǎo),得到以下結(jié)果:
可以用以下式子進(jìn)行近似:
雖然上式是梯度的無(wú)偏估計(jì),但是方差比較大,所以添加一個(gè)baseline b:
1.3 并行訓(xùn)練

為了加速訓(xùn)練過(guò)程,采用了parameter-server機(jī)制,一共有S個(gè)參數(shù)服務(wù)器,保存的是K個(gè)Controller的復(fù)制,每個(gè)復(fù)制品會(huì)采樣m個(gè)不同的子結(jié)構(gòu),這樣可以同時(shí)進(jìn)行訓(xùn)練,然后每個(gè)Controller收集m個(gè)子結(jié)構(gòu)得到的梯度,然后將更新結(jié)果提交到參數(shù)服務(wù)器。
1.4 RNN的結(jié)構(gòu)生成
RNN和LSTM都是接收和作為輸入,得到輸出結(jié)果。可以將這個(gè)過(guò)程看作一個(gè)樹(shù),控制器RNN需要去標(biāo)記每個(gè)節(jié)點(diǎn)的具體方法,比如加法、乘法、激活函數(shù)等,來(lái)合并兩個(gè)輸入得到一個(gè)輸出結(jié)果。受LSTM的啟發(fā),引入表示記憶狀態(tài)。

對(duì)照?qǐng)D很容易理解,需要解釋的就是Cell Inject和Cell Indices,首先看最后的輸出0,這代表需要計(jì)算Tree Index 0輸出結(jié)果,具體計(jì)算方法是Cell Inject決定的;倒數(shù)第二個(gè)預(yù)測(cè)為1,這代表的輸出由Tree Index1的輸出決定。
1.5 實(shí)驗(yàn)結(jié)果

在訓(xùn)練了12800個(gè)結(jié)構(gòu)以后,找到了在驗(yàn)證集上最優(yōu)的結(jié)構(gòu)。然后使用grid search方法搜索學(xué)習(xí)率、weight decay、batchnorm epsilon和衰減學(xué)習(xí)率的epoch。
可以看出,結(jié)果上和人工設(shè)計(jì)的網(wǎng)絡(luò)架構(gòu)差距不是很大,另外一個(gè)特點(diǎn)是NAS-RL得到的網(wǎng)絡(luò)深度都很淺,否則搜索空間會(huì)過(guò)大,組合爆炸。
有問(wèn)題或者想加入交流群歡迎加筆者微信交流,請(qǐng)注明來(lái)意。
