<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          【神經(jīng)網(wǎng)絡(luò)搜索】DARTS: Differentiable Architecture Search

          共 3971字,需瀏覽 8分鐘

           ·

          2021-03-03 09:23

          【GiantPandaCV】DARTS將離散的搜索空間松弛,從而可以用梯度的方式進(jìn)行優(yōu)化,從而求解神經(jīng)網(wǎng)絡(luò)搜索問題。本文首發(fā)于GiantPandaCV,未經(jīng)允許,不得轉(zhuǎn)載。

          1. 簡介

          此論文之前的NAS大部分都是使用強(qiáng)化學(xué)習(xí)或者進(jìn)化算法等在離散的搜索空間中找到最優(yōu)的網(wǎng)絡(luò)結(jié)構(gòu)。而DARTS的出現(xiàn),開辟了一個(gè)新的分支,將離散的搜索空間進(jìn)行松弛,得到連續(xù)的搜索空間,進(jìn)而可以使用梯度優(yōu)化的方處理神經(jīng)網(wǎng)絡(luò)搜索問題。DARTS將NAS建模為一個(gè)兩級(jí)優(yōu)化問題(Bi-Level Optimization),通過使用Gradient Decent的方法進(jìn)行交替優(yōu)化,從而可以求解出最優(yōu)的網(wǎng)絡(luò)架構(gòu)。DARTS也屬于One-Shot NAS的方法,也就是先構(gòu)建一個(gè)超網(wǎng),然后從超網(wǎng)中得到最優(yōu)子網(wǎng)絡(luò)的方法。

          2. 貢獻(xiàn)

          DARTS文章一共有三個(gè)貢獻(xiàn):

          • 基于二級(jí)最優(yōu)化方法提出了一個(gè)全新的可微分的神經(jīng)網(wǎng)絡(luò)搜索方法。
          • 在CIFAR-10和PTB(NLP數(shù)據(jù)集)上都達(dá)到了非常好的結(jié)果。
          • 和之前的不可微分方式的網(wǎng)絡(luò)搜索相比,效率大幅度提升,可以在單個(gè)GPU上訓(xùn)練出一個(gè)滿意的模型。

          筆者這里補(bǔ)一張對(duì)比圖,來自之前筆者翻譯的一篇綜述:<NAS的挑戰(zhàn)和解決方案-一份全面的綜述>

          ImageNet上各種方法對(duì)比,DARTS屬于Gradient Optimization方法

          簡單一對(duì)比,DARTS開創(chuàng)的Gradient Optimization方法使用的GPU Days就可以看出結(jié)果非常驚人,與基于強(qiáng)化學(xué)習(xí)、進(jìn)化算法等相比,DARTS不愧是年輕人的第一個(gè)NAS模型。

          3. 方法

          DARTS采用的是Cell-Based網(wǎng)絡(luò)架構(gòu)搜索方法,也分為Normal Cell和Reduction Cell兩種,分別搜索完成以后會(huì)通過拼接的方式形成完整網(wǎng)絡(luò)。在DARTS中假設(shè)每個(gè)Cell都有兩個(gè)輸入,一個(gè)輸出。對(duì)于Convolution Cell來說,輸入的節(jié)點(diǎn)是前兩層的輸出;對(duì)于Recurrent Cell來說,輸入為當(dāng)前步和上一步的隱藏狀態(tài)。

          DARTS核心方法可以用下面這四個(gè)圖來講解。

          DARTS Overview

          (a) 圖是一個(gè)有向無環(huán)圖,并且每個(gè)后邊的節(jié)點(diǎn)都會(huì)與前邊的節(jié)點(diǎn)相連,比如節(jié)點(diǎn)3一定會(huì)和節(jié)點(diǎn)0,1,2都相連。這里的節(jié)點(diǎn)可以理解為特征圖;邊代表采用的操作,比如卷積、池化等。

          引入數(shù)學(xué)標(biāo)記:

           節(jié)點(diǎn)(特征圖)為: 代表第i個(gè)節(jié)點(diǎn)對(duì)應(yīng)的潛在特征表示(特征圖)。

           邊(操作)為:   代表從第i個(gè)節(jié)點(diǎn)到第j個(gè)節(jié)點(diǎn)采用的操作。

           每個(gè)節(jié)點(diǎn)的輸入輸出如下面公式表示,每個(gè)節(jié)點(diǎn)都會(huì)和之前的節(jié)點(diǎn)相連接,然后將結(jié)果通過求和的方式得到第j個(gè)節(jié)點(diǎn)的特征圖。

          $$x^{(j)}=\sum_{i<j} o^{(i,="" j)}\left(x^{(i)}\right)="" $$=""  所有的候選操作為 , 在DARTS中包括了3x3深度可分離卷積、5x5深度可分離卷積、3x3空洞卷積、5x5空洞卷積、3x3最大化池化、3x3平均池化,恒等,直連,共8個(gè)操作。


          (b) 圖是一個(gè)超網(wǎng),將每個(gè)邊都擴(kuò)展了8個(gè)操作,通過這種方式可以將離散的搜索空間松弛化。具體的操作根據(jù)如下公式:

          這個(gè)可以分為兩個(gè)部分理解,一個(gè)是代表操作,一個(gè)代表選擇概率 ,這是一個(gè)softmax構(gòu)成的概率,其中表示 第i個(gè)節(jié)點(diǎn)到第j個(gè)節(jié)點(diǎn)之間操作的權(quán)重,這也是之后需要搜索的網(wǎng)絡(luò)結(jié)構(gòu)參數(shù),會(huì)影響該操作的概率。即以下公式:

          左側(cè)代表當(dāng)前操作的概率,右側(cè)代表當(dāng)前操作的參數(shù)。

          (c)和(d)圖 是保留的邊,訓(xùn)練完成以后,從所有的邊中找到概率最大的邊,即以下公式:

          4. 數(shù)學(xué)推導(dǎo)

          DARTS將NAS問題看作二級(jí)最優(yōu)化問題,具體定義如下:

          代表當(dāng)前網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)的情況下,訓(xùn)練獲得的最優(yōu)的網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)。

          第一行代表:在驗(yàn)證數(shù)據(jù)集中,在特定網(wǎng)絡(luò)操作參數(shù)w下,通過訓(xùn)練獲得最優(yōu)的網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)

          第二行表示:在訓(xùn)練數(shù)據(jù)集中,在特定網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)下,通過訓(xùn)練獲得最優(yōu)的網(wǎng)絡(luò)操作參數(shù)

          條件:在結(jié)構(gòu)確定的情況下,獲得最優(yōu)的網(wǎng)絡(luò)操作權(quán)重

          ----- 結(jié)構(gòu)確定,訓(xùn)練好卷積核

          目標(biāo):在網(wǎng)絡(luò)操作權(quán)重確定的情況下,獲得最優(yōu)的結(jié)構(gòu)

          ----- 卷積核不動(dòng),選擇更好的結(jié)構(gòu)

          最簡單的方法是通過交替優(yōu)化參數(shù)和參數(shù), 來得到最優(yōu)的結(jié)果,偽代碼如下:

          DARTS偽代碼

          交替優(yōu)化的復(fù)雜度非常高,是, 這種復(fù)雜度不可能投入使用,所以要將復(fù)雜度進(jìn)行優(yōu)化,用復(fù)雜度低的公式近似目標(biāo)函數(shù)。

          這種近似方法在Meta Learning中經(jīng)常用到,詳見《Model-agnostic meta-learning for fast adaptation of deep networks》,也就是通過使用單個(gè)step的訓(xùn)練調(diào)整w,讓這個(gè)結(jié)果來近似

          然后對(duì)右側(cè)公式進(jìn)行推導(dǎo),得到梯度優(yōu)化以后的表達(dá)式:

          師兄提供

          這里求梯度使用的是鏈?zhǔn)椒▌t,回顧一下:

          則梯度計(jì)算為:

          或者

          師兄提供

          上述公式中Di代表對(duì)的第i項(xiàng)的偏導(dǎo)。


          手敲公式太痛苦了

          整理以后結(jié)果就是:

          計(jì)算結(jié)果

          減號(hào)后邊的是二次梯度,權(quán)重的梯度求解很麻煩,這里使用泰勒公式將二階轉(zhuǎn)為一階(h是一個(gè)很小的值)。

          泰勒公式復(fù)習(xí)

          利用最右下角的公式:

          ,, , , 代入可得(其中經(jīng)驗(yàn)上設(shè)置)

          其中

          這樣就可以將二次梯度轉(zhuǎn)化為多個(gè)一次梯度。到這里復(fù)雜度從優(yōu)化到

          一階近似: 當(dāng), 下面式子的二階倒數(shù)部分就消失了,這樣模型的梯度計(jì)算可能不夠準(zhǔn)確,效果雖然不如二階,但是計(jì)算速度快。只需要假設(shè)當(dāng)前的就是, 然后啟發(fā)式優(yōu)化驗(yàn)證集上的loss值即可。

          計(jì)算結(jié)果

          代碼實(shí)現(xiàn)上也有一定的區(qū)別,代碼將在下一篇講解。

          5. 實(shí)驗(yàn)設(shè)置

          這里我們暫且先關(guān)注CIFAR10上的實(shí)驗(yàn)效果。DARTS構(gòu)成網(wǎng)絡(luò)的方式之前已經(jīng)提到了,首先為每個(gè)單元內(nèi)布使用DARTS進(jìn)行搜索,通過在驗(yàn)證集上的表現(xiàn)決定最好的單元然后使用這些單元構(gòu)建更大的網(wǎng)絡(luò)架構(gòu),然后從頭開始訓(xùn)練,報(bào)告在測(cè)試集上的表現(xiàn)。

          CIFAR10上搜索操作有:

          • 3x3 & 5x5 可分離卷積
          • 3x3 & 5x5 空洞可分離卷積
          • 3x3 max & avg pooling
          • identiy
          • zero

          實(shí)驗(yàn)詳細(xì)設(shè)置:

          • 所有操作的stride=1, 為了保證他們空間分辨率,使用了padding。

          • 卷積操作使用的是ReLU-Conv-BN的順序,并且每個(gè)可分離卷積會(huì)被使用兩次。

          • 卷積單元包括了7個(gè)節(jié)點(diǎn),輸出節(jié)點(diǎn)為所有中間節(jié)點(diǎn)concate以后的結(jié)果。

          • 網(wǎng)絡(luò)整體深度的1/3和2/3處強(qiáng)制設(shè)置了reduction cell來降低空間分辨率。

          • 網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)是被所有normal cell共享的,同理是被所有reduction cell共享的。

          • 并沒有使用全局batch normalization, 使用的是batch-specific statistic batch normalization

          • CIFAR10一半的訓(xùn)練集作為驗(yàn)證集。

          • 8個(gè)單元的消亡了使用DARTS訓(xùn)練50個(gè)epoch, batch size設(shè)置為64, 初始通道個(gè)數(shù)為16。

          • 使用momentum SGD來優(yōu)化權(quán)重,初始學(xué)習(xí)率設(shè)置為0.025,momentum 0.9 weight decay為0.0004.

          • 網(wǎng)絡(luò)架構(gòu)參數(shù) 使用0作為初始化,使用Adam優(yōu)化器來優(yōu)化參數(shù),初始學(xué)習(xí)率設(shè)置為0.0004,momentum為(0.5,0.999)weight decay=0.001。

          CIFAR10上搜索結(jié)果和其他算法對(duì)比

          可以看到,搜索結(jié)果最終是優(yōu)于AmoebaNet-A和NASNet-A。具體搜索得到的Normal Cell和Reduction Cell可視化如下:

          Normal Cell & Reduction Cell for CIFAR10

          網(wǎng)絡(luò)評(píng)價(jià)

          網(wǎng)絡(luò)優(yōu)化對(duì)初始化值是非常敏感的,為了確定最終的網(wǎng)絡(luò)結(jié)構(gòu),DARTS將使用隨機(jī)種子運(yùn)行四次,每次得到的Cell都會(huì)在訓(xùn)練集上從頭開始訓(xùn)練很短一段時(shí)間大概100 epochs , 然后根據(jù)驗(yàn)證集上得到的最優(yōu)結(jié)果決定最終的架構(gòu)。

          為了驗(yàn)證被選擇的架構(gòu):

          • 隨機(jī)初始化權(quán)重
          • 從頭開始訓(xùn)練
          • 報(bào)告測(cè)試集上的模型表現(xiàn)

          CIFAR10搜索的模型遷移到ImageNet更多細(xì)節(jié):

          • 20個(gè)單元的大型網(wǎng)絡(luò)使用了96的batch size, 訓(xùn)練了600個(gè)epoch
          • 初始通道個(gè)數(shù)由16修改為36,為了讓模型的參數(shù)和其他模型參數(shù)量相當(dāng)。
          • 其他參數(shù)設(shè)置和搜索過程中參數(shù)一樣
          • 使用了cutout的數(shù)據(jù)增強(qiáng)方法,以0.2的概率進(jìn)行path dropout
          • 使用了auxiliary tower(輔助頭,在這里施加loss, 提前進(jìn)行反向傳播,InceptionV3中提出)
          • 使用PyTorch在單個(gè)GPU上花費(fèi)1.5天時(shí)間訓(xùn)練完ImageNet,獨(dú)立訓(xùn)練10次作為最終的結(jié)果。
          CIFAR10上搜索結(jié)果

          使用二階優(yōu)化方法+cutout的數(shù)據(jù)增強(qiáng)方法,DARTS能達(dá)到約2.76的準(zhǔn)確率,筆者使用nni進(jìn)行了實(shí)驗(yàn),最終結(jié)果是2.6%的Test Error。

          nni上darts的實(shí)驗(yàn)結(jié)果

          6. 致謝&參考

          感謝師兄提供的資料,以及知乎上兩位大佬,他們文章鏈接如下:

          薰風(fēng)讀論文|DARTS—年輕人的第一個(gè)NAS模型 https://zhuanlan.zhihu.com/p/156832334

          【論文筆記】DARTS公式推導(dǎo) https://zhuanlan.zhihu.com/p/73037439


          瀏覽 60
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  一本无码在线播放 | 日本黄色电影免费看 | 91久久成人无码 | 日本中文字幕A√ | 青娱乐99999在线中文字幕 |