可逆神經(jīng)網(wǎng)絡(Invertible Neural Networks)詳細解析:讓神經(jīng)網(wǎng)絡更加輕量化

來源:PaperWeekly 本文約3600字,建議閱讀7分鐘
本文以可逆殘差網(wǎng)絡(The Reversible Residual Network: Backpropagation Without Storing Activations)作為基礎進行分析。
因為編碼和解碼使用相同的參數(shù),所以 model 是輕量級的??赡娴慕翟刖W(wǎng)絡 InvDN 只有 DANet 網(wǎng)絡參數(shù)量的 4.2%,但是 InvDN 的降噪性能更好。 由于可逆網(wǎng)絡是信息無損的,所以它能保留輸入數(shù)據(jù)的細節(jié)信息。 無論網(wǎng)絡的深度如何,可逆網(wǎng)絡都使用恒定的內存來計算梯度。
import?torch
from?torchvision?import?models
from?torchsummary?import?summary
device?=?torch.device('cuda'?if?torch.cuda.is_available()?else?'cpu')
vgg?=?models.vgg16().to(device)
summary(vgg,?(3,?224,?224))
結果:
----------------------------------------------------------------
????????Layer?(type)???????????????Output?Shape?????????Param?#
================================================================
????????????Conv2d-1?????????[-1,?64,?224,?224]???????????1,792
??????????????ReLU-2?????????[-1,?64,?224,?224]???????????????0
????????????Conv2d-3?????????[-1,?64,?224,?224]??????????36,928
??????????????ReLU-4?????????[-1,?64,?224,?224]???????????????0
?????????MaxPool2d-5?????????[-1,?64,?112,?112]???????????????0
????????????Conv2d-6????????[-1,?128,?112,?112]??????????73,856
??????????????ReLU-7????????[-1,?128,?112,?112]???????????????0
????????????Conv2d-8????????[-1,?128,?112,?112]?????????147,584
??????????????ReLU-9????????[-1,?128,?112,?112]???????????????0
????????MaxPool2d-10??????????[-1,?128,?56,?56]???????????????0
???????????Conv2d-11??????????[-1,?256,?56,?56]?????????295,168
?????????????ReLU-12??????????[-1,?256,?56,?56]???????????????0
???????????Conv2d-13??????????[-1,?256,?56,?56]?????????590,080
?????????????ReLU-14??????????[-1,?256,?56,?56]???????????????0
???????????Conv2d-15??????????[-1,?256,?56,?56]?????????590,080
?????????????ReLU-16??????????[-1,?256,?56,?56]???????????????0
????????MaxPool2d-17??????????[-1,?256,?28,?28]???????????????0
???????????Conv2d-18??????????[-1,?512,?28,?28]???????1,180,160
?????????????ReLU-19??????????[-1,?512,?28,?28]???????????????0
???????????Conv2d-20??????????[-1,?512,?28,?28]???????2,359,808
?????????????ReLU-21??????????[-1,?512,?28,?28]???????????????0
???????????Conv2d-22??????????[-1,?512,?28,?28]???????2,359,808
?????????????ReLU-23??????????[-1,?512,?28,?28]???????????????0
????????MaxPool2d-24??????????[-1,?512,?14,?14]???????????????0
???????????Conv2d-25??????????[-1,?512,?14,?14]???????2,359,808
?????????????ReLU-26??????????[-1,?512,?14,?14]???????????????0
???????????Conv2d-27??????????[-1,?512,?14,?14]???????2,359,808
?????????????ReLU-28??????????[-1,?512,?14,?14]???????????????0
???????????Conv2d-29??????????[-1,?512,?14,?14]???????2,359,808
?????????????ReLU-30??????????[-1,?512,?14,?14]???????????????0
????????MaxPool2d-31????????????[-1,?512,?7,?7]???????????????0
???????????Linear-32?????????????????[-1,?4096]?????102,764,544
?????????????ReLU-33?????????????????[-1,?4096]???????????????0
??????????Dropout-34?????????????????[-1,?4096]???????????????0
???????????Linear-35?????????????????[-1,?4096]??????16,781,312
?????????????ReLU-36?????????????????[-1,?4096]???????????????0
??????????Dropout-37?????????????????[-1,?4096]???????????????0
???????????Linear-38?????????????????[-1,?1000]???????4,097,000
================================================================
Total?params:?138,357,544
Trainable?params:?138,357,544
Non-trainable?params:?0
----------------------------------------------------------------
Input?size?(MB):?0.57
Forward/backward?pass?size?(MB):?218.59
Params?size?(MB):?527.79
Estimated?Total?Size?(MB):?746.96
----------------------------------------------------------------
接下來我將先從可逆神經(jīng)網(wǎng)絡講起,然后是神經(jīng)網(wǎng)絡的反向傳播,最后是標準殘差網(wǎng)絡。對反向傳播算法和標準殘差網(wǎng)絡比較熟悉的小伙伴,可以只看第一節(jié):可逆神經(jīng)網(wǎng)絡。如果各位小伙伴不熟悉反向傳播算法和標準殘差網(wǎng)絡,建議先看第二節(jié):反向傳播(BP)算法和第三節(jié):殘差網(wǎng)絡(Residual Network)。本文1.2和1.3.4摘錄自 @阿亮。
可逆神經(jīng)網(wǎng)絡
可逆網(wǎng)絡具有的性質:
網(wǎng)絡的輸入、輸出的大小必須一致。 網(wǎng)絡的雅可比行列式不為 0。
1.1 什么是雅可比行列式?



1.2 雅可比行列式與神經(jīng)網(wǎng)絡的關系




其實這里跟矩陣運算很像,矩陣可逆的條件也是矩陣的雅可比行列式不為 0,雅可比矩陣可以理解為矩陣的一階導數(shù)。
假設可逆網(wǎng)絡的表達式為:


它的雅可比矩陣為:

1.3 可逆殘差網(wǎng)絡(Reversible Residual Network)

論文標題:
The Reversible Residual Network: Backpropagation Without Storing Activations
論文鏈接:
https://arxiv.org/abs/1707.04585
多倫多大學的 Aidan N.Gomez 和 Mengye Ren 提出了可逆殘差神經(jīng)網(wǎng)絡,當前層的激活結果可由下一層的結果計算得出,也就是如果我們知道網(wǎng)絡層最后的結果,就可以反推前面每一層的中間結果。這樣我們只需要存儲網(wǎng)絡的參數(shù)和最后一層的結果即可,激活結果的存儲與網(wǎng)絡的深度無關了,將大幅減少顯存占用。令人驚訝的是,實驗結果顯示,可逆殘差網(wǎng)絡的表現(xiàn)并沒有顯著下降,與之前的標準殘差網(wǎng)絡實驗結果基本旗鼓相當。

公式表示:



1.3.2 不用存儲激活結果的反向傳播
為了更好地計算反向傳播的步驟,我們修改一下上述正向計算和逆向計算的公式:


1.3.3 計算開銷
一個 N 個連接的神經(jīng)網(wǎng)絡,正向計算的理論加乘開銷為 N,反向傳播求導的理論加乘開銷為 2N(反向求導包含復合函數(shù)求導連乘),而可逆網(wǎng)絡多一步需要反向計算輸入值的操作,所以理論計算開銷為 4N,比普通網(wǎng)絡開銷約多出 33% 左右。但是在實際操作中,正向和反向的計算開銷在 GPU 上差不多,可以都理解為 N。那么這樣的話,普通網(wǎng)絡的整體計算開銷為 2N,可逆網(wǎng)絡的整體開銷為 3N,也就是多出了約 50%。
1.3.4 雅可比行列式的計算













因為是對偶的形式,所以這里的行列式也為 1.
反向傳播(BP)算法

x1,x2,x3:表示 3 個輸入層節(jié)點。 :表示從 t-1 層到 t 層的權重參數(shù),j 表示 t 層的第 j 個節(jié)點,i 表示 t-1 層的第 i 個節(jié)點。 :表示 t 層的第 i 個激活后輸出結果。 g(x):表示激活函數(shù)。
隱藏層(網(wǎng)絡的第二層)

輸出層(網(wǎng)絡的最后一層)

以平方誤差來計算反向傳播的過程,代價函數(shù)表示如下:



l=2,3 表示第幾層,j 表示某一層的第幾個節(jié)點。替換表示后如下:


從上述公式可以看出,如果神經(jīng)單元誤差 δ 可以求出來,那么總誤差對每一層的權重 w 和偏置 b 的偏導數(shù)就可以求出來,接下來就可以利用梯度下降法來優(yōu)化參數(shù)了。
求解每一層的 δ:
輸出層

隱藏層


殘差網(wǎng)絡(Residual Network)
殘差網(wǎng)絡主要可以解決兩個問題(其結構如下圖):
1)梯度消失問題;
2)網(wǎng)絡退化問題。





