<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          可逆神經(jīng)網(wǎng)絡(Invertible Neural Networks)詳細解析:讓神經(jīng)網(wǎng)絡更加輕量化

          共 7696字,需瀏覽 16分鐘

           ·

          2022-04-24 21:08


          來源:PaperWeekly

          本文約3600字,建議閱讀7分鐘

          本文以可逆殘差網(wǎng)絡(The Reversible Residual Network: Backpropagation Without Storing Activations)作為基礎進行分析。

          為什么要用可逆網(wǎng)絡呢?

          1. 因為編碼和解碼使用相同的參數(shù),所以 model 是輕量級的??赡娴慕翟刖W(wǎng)絡 InvDN 只有 DANet 網(wǎng)絡參數(shù)量的 4.2%,但是 InvDN 的降噪性能更好。
          2. 由于可逆網(wǎng)絡是信息無損的,所以它能保留輸入數(shù)據(jù)的細節(jié)信息。
          3. 無論網(wǎng)絡的深度如何,可逆網(wǎng)絡都使用恒定的內存來計算梯度。

          其中最主要目的就是為了減少內存的消耗,當前所有的神經(jīng)網(wǎng)絡都采用反向傳播的方式來訓練,反向傳播算法需要存儲網(wǎng)絡的中間結果來計算梯度,而且其對內存的消耗與網(wǎng)絡單元數(shù)成正比。這也就意味著,網(wǎng)絡越深越廣,對內存的消耗越大,這將成為很多應用的瓶頸。

          下面是 Pytorch summary 的結果,F(xiàn)orward/backward pass size(MB): 218.59 就是需要保存的中間變量大小,可以看出這部分占據(jù)了很大部分顯存(隨著網(wǎng)絡深度的增加,中間變量占據(jù)顯存量會一直增加,resnet152(size=224)的中間變量更是占據(jù)總共內存的 606.6÷836.79≈0.725 )。如果不存儲中間層結果,那么就可以大幅減少 GPU 的顯存占用,有助于訓練更深更廣的網(wǎng)絡。

          import?torch
          from?torchvision?import?models
          from?torchsummary?import?summary

          device?=?torch.device('cuda'?if?torch.cuda.is_available()?else?'cpu')
          vgg?=?models.vgg16().to(device)

          summary(vgg,?(3,?224,?224))


          結果:


          ----------------------------------------------------------------
          ????????Layer?(type)???????????????Output?Shape?????????Param?#
          ================================================================
          ????????????Conv2d-1?????????[-1,?64,?224,?224]???????????1,792
          ??????????????ReLU-2?????????[-1,?64,?224,?224]???????????????0
          ????????????Conv2d-3?????????[-1,?64,?224,?224]??????????36,928
          ??????????????ReLU-4?????????[-1,?64,?224,?224]???????????????0
          ?????????MaxPool2d-5?????????[-1,?64,?112,?112]???????????????0
          ????????????Conv2d-6????????[-1,?128,?112,?112]??????????73,856
          ??????????????ReLU-7????????[-1,?128,?112,?112]???????????????0
          ????????????Conv2d-8????????[-1,?128,?112,?112]?????????147,584
          ??????????????ReLU-9????????[-1,?128,?112,?112]???????????????0
          ????????MaxPool2d-10??????????[-1,?128,?56,?56]???????????????0
          ???????????Conv2d-11??????????[-1,?256,?56,?56]?????????295,168
          ?????????????ReLU-12??????????[-1,?256,?56,?56]???????????????0
          ???????????Conv2d-13??????????[-1,?256,?56,?56]?????????590,080
          ?????????????ReLU-14??????????[-1,?256,?56,?56]???????????????0
          ???????????Conv2d-15??????????[-1,?256,?56,?56]?????????590,080
          ?????????????ReLU-16??????????[-1,?256,?56,?56]???????????????0
          ????????MaxPool2d-17??????????[-1,?256,?28,?28]???????????????0
          ???????????Conv2d-18??????????[-1,?512,?28,?28]???????1,180,160
          ?????????????ReLU-19??????????[-1,?512,?28,?28]???????????????0
          ???????????Conv2d-20??????????[-1,?512,?28,?28]???????2,359,808
          ?????????????ReLU-21??????????[-1,?512,?28,?28]???????????????0
          ???????????Conv2d-22??????????[-1,?512,?28,?28]???????2,359,808
          ?????????????ReLU-23??????????[-1,?512,?28,?28]???????????????0
          ????????MaxPool2d-24??????????[-1,?512,?14,?14]???????????????0
          ???????????Conv2d-25??????????[-1,?512,?14,?14]???????2,359,808
          ?????????????ReLU-26??????????[-1,?512,?14,?14]???????????????0
          ???????????Conv2d-27??????????[-1,?512,?14,?14]???????2,359,808
          ?????????????ReLU-28??????????[-1,?512,?14,?14]???????????????0
          ???????????Conv2d-29??????????[-1,?512,?14,?14]???????2,359,808
          ?????????????ReLU-30??????????[-1,?512,?14,?14]???????????????0
          ????????MaxPool2d-31????????????[-1,?512,?7,?7]???????????????0
          ???????????Linear-32?????????????????[-1,?4096]?????102,764,544
          ?????????????ReLU-33?????????????????[-1,?4096]???????????????0
          ??????????Dropout-34?????????????????[-1,?4096]???????????????0
          ???????????Linear-35?????????????????[-1,?4096]??????16,781,312
          ?????????????ReLU-36?????????????????[-1,?4096]???????????????0
          ??????????Dropout-37?????????????????[-1,?4096]???????????????0
          ???????????Linear-38?????????????????[-1,?1000]???????4,097,000
          ================================================================
          Total?params:?138,357,544
          Trainable?params:?138,357,544
          Non-trainable?params:?0
          ----------------------------------------------------------------
          Input?size?(MB):?0.57
          Forward/backward?pass?size?(MB):?218.59
          Params?size?(MB):?527.79
          Estimated?Total?Size?(MB):?746.96
          ----------------------------------------------------------------

          接下來我將先從可逆神經(jīng)網(wǎng)絡講起,然后是神經(jīng)網(wǎng)絡的反向傳播,最后是標準殘差網(wǎng)絡。對反向傳播算法和標準殘差網(wǎng)絡比較熟悉的小伙伴,可以只看第一節(jié):可逆神經(jīng)網(wǎng)絡。如果各位小伙伴不熟悉反向傳播算法和標準殘差網(wǎng)絡,建議先看第二節(jié):反向傳播(BP)算法和第三節(jié):殘差網(wǎng)絡(Residual Network)。本文1.2和1.3.4摘錄自 @阿亮。


          可逆神經(jīng)網(wǎng)絡


          可逆網(wǎng)絡具有的性質:


          1. 網(wǎng)絡的輸入、輸出的大小必須一致。
          2. 網(wǎng)絡的雅可比行列式不為 0。


          1.1 什么是雅可比行列式?


          雅可比行列式通常稱為雅可比式(Jacobian),它是以 n 個 n 元函數(shù)的偏導數(shù)為元素的行列式 。事實上,在函數(shù)都連續(xù)可微(即偏導數(shù)都連續(xù))的前提之下,它就是函數(shù)組的微分形式下的系數(shù)矩陣(即雅可比矩陣)的行列式。若因變量對自變量連續(xù)可微,而自變量對新變量連續(xù)可微,則因變量也對新變量連續(xù)可微。這可用行列式的乘法法則和偏導數(shù)的連鎖法則直接驗證。也類似于導數(shù)的連鎖法則。偏導數(shù)的連鎖法則也有類似的公式;這常用于重積分的計算中。




          1.2 雅可比行列式與神經(jīng)網(wǎng)絡的關系


          為什么神經(jīng)網(wǎng)絡會與雅可比行列式有關系?這里我借用李宏毅老師的 ppt(12-14頁)。想看視頻的可以到 b 站上看。





          簡單的來講就是?,他們的分布之間的關系就變?yōu)?,又因為有?,所以??這個網(wǎng)絡的雅可比行列式不為 0 才行。


          順便提一下,flow-based Model 優(yōu)化的損失函數(shù)如下:


          其實這里跟矩陣運算很像,矩陣可逆的條件也是矩陣的雅可比行列式不為 0,雅可比矩陣可以理解為矩陣的一階導數(shù)。


          假設可逆網(wǎng)絡的表達式為:



          它的雅可比矩陣為:



          其行列式為 1。

          1.3 可逆殘差網(wǎng)絡(Reversible Residual Network)



          論文標題:

          The Reversible Residual Network: Backpropagation Without Storing Activations

          論文鏈接:

          https://arxiv.org/abs/1707.04585


          多倫多大學的 Aidan N.Gomez 和 Mengye Ren 提出了可逆殘差神經(jīng)網(wǎng)絡,當前層的激活結果可由下一層的結果計算得出,也就是如果我們知道網(wǎng)絡層最后的結果,就可以反推前面每一層的中間結果。這樣我們只需要存儲網(wǎng)絡的參數(shù)和最后一層的結果即可,激活結果的存儲與網(wǎng)絡的深度無關了,將大幅減少顯存占用。令人驚訝的是,實驗結果顯示,可逆殘差網(wǎng)絡的表現(xiàn)并沒有顯著下降,與之前的標準殘差網(wǎng)絡實驗結果基本旗鼓相當。


          1.3.1 可逆塊結構

          可逆神經(jīng)網(wǎng)絡將每一層分割成兩部分,分別為??和?,每一個可逆塊的輸入是?,輸出是?。其結構如下:

          正向計算圖示:



          公式表示:



          逆向計算圖示:


          公式表示:


          其中 F 和 G 都是相似的殘差函數(shù),參考上圖殘差網(wǎng)絡??赡鎵K的跨距只能為 1,也就是說可逆塊必須一個接一個連接,中間不能采用其它網(wǎng)絡形式銜接,否則的話就會丟失信息,并且無法可逆計算了,這點與殘差塊不一樣。如果一定要采取跟殘差塊相似的結構,也就是中間一部分采用普通網(wǎng)絡形式銜接,那中間這部分的激活結果就必須顯式的存起來。


          1.3.2 不用存儲激活結果的反向傳播


          為了更好地計算反向傳播的步驟,我們修改一下上述正向計算和逆向計算的公式:



          盡管??和??的值是相同的,但是兩個變量在圖中卻代表不同的節(jié)點,所以在反向傳播中它們的總體導數(shù)是不一樣的。?的導數(shù)包含通過??產(chǎn)生的間接影響,而??的導數(shù)卻不受??的任何影響。

          在反向傳播計算流程中,先給出最后一層的激活值??和誤差傳播的總體導數(shù)?,然后要計算出其輸入值??和對應的導數(shù)?,以及殘差函數(shù) F 和 G 中權重參數(shù)的總體導數(shù),求解步驟如下:



          1.3.3 計算開銷


          一個 N 個連接的神經(jīng)網(wǎng)絡,正向計算的理論加乘開銷為 N,反向傳播求導的理論加乘開銷為 2N(反向求導包含復合函數(shù)求導連乘),而可逆網(wǎng)絡多一步需要反向計算輸入值的操作,所以理論計算開銷為 4N,比普通網(wǎng)絡開銷約多出 33% 左右。但是在實際操作中,正向和反向的計算開銷在 GPU 上差不多,可以都理解為 N。那么這樣的話,普通網(wǎng)絡的整體計算開銷為 2N,可逆網(wǎng)絡的整體開銷為 3N,也就是多出了約 50%。


          1.3.4 雅可比行列式的計算



          其編碼公式如下:


          其解碼公式如下:


          為了計算雅可比矩陣,我們更直觀的寫成下面的編碼公式:


          它的雅可比矩陣為:


          其實上面這個雅可比行列式也是1,因為這里?,它們的系數(shù)是一樣的。

          有另外一種解釋方式就是把這種對偶的形式切成兩半:




          其行列式為 1.



          因為是對偶的形式,所以這里的行列式也為 1.


          因為?,所以其行列式也為 1。


          反向傳播(BP)算法



          上圖中符號的含義:

          • x1,x2,x3:表示 3 個輸入層節(jié)點。
          • :表示從 t-1 層到 t 層的權重參數(shù),j 表示 t 層的第 j 個節(jié)點,i 表示 t-1 層的第 i 個節(jié)點。
          • :表示 t 層的第 i 個激活后輸出結果。
          • g(x):表示激活函數(shù)。

          正向傳播計算過程:

          • 隱藏層(網(wǎng)絡的第二層)


          • 輸出層(網(wǎng)絡的最后一層)


          反向傳播計算過程:

          以單個樣本為例,假設輸入向量是 [x1,x2,x3],目標輸出值是 [y1,y2],代價函數(shù)用 L 表示。反向傳播的總體原理就是根據(jù)總體輸出誤差,反向傳播回網(wǎng)絡,通過計算每一層節(jié)點的梯度,利用梯度下降法原理,更新每一層的網(wǎng)絡權重 w 和偏置 b,這也是網(wǎng)絡學習的過程。誤差反向傳播的優(yōu)點就是可以把繁雜的導數(shù)計算以數(shù)列遞推的形式來表示, 簡化了計算過程。


          以平方誤差來計算反向傳播的過程,代價函數(shù)表示如下:



          根據(jù)導數(shù)的鏈式法則反向求解隱藏 -> 輸出層、輸入層 -> 隱藏層的權重表示:


          引入新的誤差求導表示形式,稱為神經(jīng)單元誤差:



          l=2,3 表示第幾層,j 表示某一層的第幾個節(jié)點。替換表示后如下:



          所以我們可以歸納出一般的計算公式:


          從上述公式可以看出,如果神經(jīng)單元誤差 δ 可以求出來,那么總誤差對每一層的權重 w 和偏置 b 的偏導數(shù)就可以求出來,接下來就可以利用梯度下降法來優(yōu)化參數(shù)了。


          求解每一層的 δ:


          • 輸出層


          • 隱藏層


          也就是說,我們根據(jù)輸出層的神經(jīng)誤差單元 δ 就可以直接求出隱藏層的神經(jīng)誤差單元,進而省去了隱藏層的繁雜的求導過程,我們可以得出更一般的計算過程:


          從而得出 l 層神經(jīng)單元誤差和 l+1 層神經(jīng)單元誤差的關系。這就是誤差反向傳播算法,只要求出輸出層的神經(jīng)單元誤差,其它層的神經(jīng)單元誤差就不需要計算偏導數(shù)了,而可以直接通過上述公式得出。

          殘差網(wǎng)絡(Residual Network)


          殘差網(wǎng)絡主要可以解決兩個問題(其結構如下圖):


          1)梯度消失問題;

          2)網(wǎng)絡退化問題。

          上述結構就是一個兩層網(wǎng)絡組成的殘差塊,殘差塊可以由 2、3 層甚至更多層組成,但是如果是一層的,就變成線性變換了,沒什么意義了。上述圖可以寫成公式如下:


          所以在第二層進入激活函數(shù)ReLU之 前 F(x)+x 組成新的輸入,也叫恒等映射。

          恒等映射就是在這個殘差塊輸入是 x 的情況下輸出依然是 x,這樣其目標就是學習讓 F(X)=0。

          這里有一個問題哈,為什么要額外加一個 x 呢,而不是讓模型直接學習 F(x)=x?

          因為讓 F(x)=0 比較容易,初始化參數(shù) W 非常小接近 0,就可以讓輸出接近 0,同時輸出如果是負數(shù),經(jīng)過第一層 Relu 后輸出依然 0,都能使得最后的 F(x)=0,也就是有多種情況都可以使得 F(x)=0;但是讓 F(x)=x 確實非常難的,因為參數(shù)都必須剛剛好才能使得最后輸出為 x。

          恒等映射有什么作用?

          恒等映射就可以解決網(wǎng)絡退化的問題,當網(wǎng)絡層數(shù)越來越深的時候,網(wǎng)絡的精度卻在下降,也就是說網(wǎng)絡自身存在一個最優(yōu)的層度結構,太深太淺都能使得模型精度下降。有了恒等映射存在,網(wǎng)絡就能夠自己學習到哪些層是冗余的,就可以無損通過這些層,理論上講再深的網(wǎng)絡都不影響其精度,解決了網(wǎng)絡退化問題。

          為什么可以解決梯度消失問題呢?

          以兩個殘差塊的結構實例圖來分析,其中每個殘差塊有 2 層神經(jīng)網(wǎng)絡組成,如下圖:


          假設激活函數(shù) ReLU 用 g(x) 函數(shù)來表示,樣本實例是 [x1,y1],即輸入是 x1,目標值是 y1,損失函數(shù)還是采用平方損失函數(shù),則每一層的計算如下:


          下面我們對第一個殘差塊的權重參數(shù)求導,根據(jù)鏈式求導法則,公式如下:


          我們可以看到求導公式中多了一個+1項,這就將原來的鏈式求導中的連乘變成了連加狀態(tài),可以有效避免梯度消失了。

          參考文獻

          [1]PPT
          ?https://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/FLOW%20(v7).pdf
          [2] 神經(jīng)網(wǎng)絡的可逆形式?
          https://zhuanlan.zhihu.com/p/268242678
          [3] 大幅減少GPU顯存占用:
          可逆殘差網(wǎng)絡(The Reversible Residual Network)?
          https://www.cnblogs.com/gczr/p/12181354.html
          [4] 雅可比行列式?
          https://baike.baidu.com/item/雅可比行列式/4709261?fr=aladdin
          [5] The Reversible Residual Network:?
          Backpropagation Without Storing Activations
          [6] pytorch-summary?
          https://github.com/sksq96/pytorch-summary

          編輯:王菁
          校對:楊學俊





          瀏覽 28
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  青娱乐在线观看网址 | 婷婷97五月天 | 日日干天天干视频 | 国产一级黄片视频在线观看 | 91丨国产丨熟女 熟女 |