回爐重造:計(jì)算圖——深入理解深度學(xué)習(xí)框架細(xì)節(jié)
點(diǎn)擊藍(lán)字
?關(guān)注我們

前言
相信各位做算法的同學(xué)都很熟悉框架的使用,但未必很清楚了解我們跑模型的時(shí)候,框架內(nèi)部在做什么,比如怎么自動(dòng)求導(dǎo),反向傳播。這一系列細(xì)節(jié)雖然用戶不需要關(guān)注,但如果能深入理解,那會(huì)對(duì)整個(gè)框架底層更加熟悉。
從一道算法題開(kāi)始
有算法基礎(chǔ)的同學(xué),應(yīng)該都知道迪杰斯特拉的雙棧算術(shù)表達(dá)式求和這個(gè)經(jīng)典算法。他的原理是利用兩個(gè)棧分別存放運(yùn)算數(shù),操作。根據(jù)不同的情況彈出棧里的元素,并進(jìn)行運(yùn)算,我們可以具體看下圖

這里討論的是最簡(jiǎn)單的情況,我們根據(jù)操作符的優(yōu)先級(jí),以及括號(hào)的種類(lèi)(左括號(hào)和右括號(hào)),分別進(jìn)行運(yùn)算,然后得到最終結(jié)果。
神經(jīng)網(wǎng)絡(luò)里怎么做?
在神經(jīng)網(wǎng)絡(luò)里,我們把數(shù)據(jù)和權(quán)重都以矩陣運(yùn)算的形式來(lái)計(jì)算得到最終的結(jié)果。舉個(gè)常見(jiàn)的例子,在全連接層中,我們都是使用矩陣乘法matmul來(lái)進(jìn)行運(yùn)算,形式如下

如圖,一個(gè)(2x3)的矩陣W和一個(gè)(3x2)的矩陣X運(yùn)算出來(lái)的結(jié)果Y1是(2x2) 那么Y可以被表示為
那后續(xù)還有一系列相關(guān)操作,比如我們可以假設(shè)
這一系列運(yùn)算,都是我們拿輸入X一層,一層的前向計(jì)算,因此這一個(gè)過(guò)程被稱為前向傳播
神經(jīng)網(wǎng)絡(luò)為了學(xué)習(xí)調(diào)節(jié)參數(shù),那就需要優(yōu)化,我們通過(guò)一個(gè)損失函數(shù)來(lái)衡量模型性能,然后使用梯度下降法對(duì)模型進(jìn)行優(yōu)化 原理如下(完整的可以參考我寫(xiě)的一篇深度學(xué)習(xí)里的優(yōu)化)

可以看到最后我們能讓loss值變小,這也能代表模型性能得到了優(yōu)化。那既然涉及到了梯度,就需要對(duì)里面的元素進(jìn)行求導(dǎo)了。那么應(yīng)該對(duì)誰(shuí)求呢, 也就是神經(jīng)網(wǎng)絡(luò)里的權(quán)重W1, W2, W3
可以觀察到,要想求各個(gè)權(quán)重,就需要從最后一層往前逐層推進(jìn)。求導(dǎo)得到各個(gè)權(quán)重對(duì)應(yīng)的梯度,這叫后向傳播。那既然算術(shù)表達(dá)式可以用雙棧來(lái)輕松的表達(dá)
對(duì)于神經(jīng)網(wǎng)絡(luò)里的運(yùn)算,需要前向傳播和后向傳播,有沒(méi)有什么好的數(shù)據(jù)結(jié)構(gòu)對(duì)其進(jìn)行抽象呢?有的,那就是我們需要說(shuō)的計(jì)算圖
計(jì)算圖
我們借用圖的結(jié)構(gòu)就能很好的表示整個(gè)前向和后向的過(guò)程。形式如下

我們?cè)賮?lái)看一個(gè)更具體的例子

(這幅圖摘自Paddle教程。
比如最后一項(xiàng)計(jì)算是
則在反向傳播中 650這一項(xiàng)對(duì)應(yīng)的梯度為1.1 1.1這一項(xiàng)對(duì)應(yīng)的梯度為650 以此類(lèi)推。
常見(jiàn)的反向傳播
卷積層的反向傳播
這里參考的是知乎一篇 Conv卷積層反向求導(dǎo) 我們寫(xiě)一個(gè)簡(jiǎn)單的1通道,3x3大小的卷積
import torchimport torch.nn as nnconv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=0, bias=False, stride=1)inputv = torch.range(1, 16).view(1, 1, 4, 4)print(inputv)out = conv(inputv)print(out)out = out.mean()out.backward()print(conv.weight.grad)
最后得到conv的梯度為
tensor([[[[ 3.5000, 4.5000, 5.5000],[ 7.5000, 8.5000, 9.5000],[11.5000, 12.5000, 13.5000]]]])
我們3x3 的卷積核形式如下

我們的數(shù)據(jù)為4x4矩陣

這里我們只關(guān)注卷積核左上角元素W1的求導(dǎo)過(guò)程 在stride=1,pad=0情況下,他的移動(dòng)過(guò)程是這樣的

白色是卷積核每次移動(dòng)覆蓋的區(qū)域,而藍(lán)色區(qū)塊,則是與權(quán)重W1經(jīng)過(guò)計(jì)算的位置
可以看到W1分別和1, 2, 5, 6這四個(gè)數(shù)字進(jìn)行計(jì)算 我們最后標(biāo)準(zhǔn)化一下
這就是權(quán)重W1對(duì)應(yīng)的梯度,以此類(lèi)推,我們可以得到9個(gè)梯度,分別對(duì)應(yīng)著3x3卷積核每個(gè)權(quán)重的梯度
卷積層求導(dǎo)的延申
其實(shí)卷積操作是可以被優(yōu)化成一個(gè)矩陣運(yùn)算的形式,該方法名為img2col 這里簡(jiǎn)單介紹下

藍(lán)色部分是我們的卷積核,我們可以攤平成1維向量,這里我們有兩個(gè)卷積核,就將2個(gè)1維向量進(jìn)行組合,得到一個(gè)核矩陣 同理,我們把輸入特征也攤平,得到輸入特征矩陣
這樣我們就可以將卷積操作,轉(zhuǎn)變成兩個(gè)矩陣相乘,最終得到輸出矩陣。而不需要用for循環(huán)嵌套,極大提升了運(yùn)算效率。
池化層的反向傳播
池化層本身并不存在參數(shù),但是不存在參數(shù)并不意味著不參加反向傳播過(guò)程。如果池化層不參加反向傳播過(guò)程,那么前面層的傳播也就中斷了。因此池化層需要將梯度傳遞到前面一層,而自身是不需要計(jì)算梯度優(yōu)化參數(shù)。
import torchimport numpy as npinputv = np.array([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12],[13, 14, 15, 16],])inputv = inputv.astype(np.float)inputv = torch.tensor(inputv,requires_grad=True).float()inputv = inputv.unsqueeze(0)inputv.retain_grad()print(inputv)pool = torch.nn.functional.max_pool2d(inputv, kernel_size=(3, 3), stride=1)print(pool)pool = torch.mean(pool)print(pool)pool.backward()print(inputv.grad)
注意這里我們打印的是input的梯度,因?yàn)槌鼗瘜幼陨聿痪邆涮荻?/p>tensor([[[0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.2500, 0.2500], [0.0000, 0.0000, 0.2500, 0.2500]]])
其中最大池化層是這樣做的

可以看到我們有4個(gè)元素進(jìn)行了最大池化,但為了保證傳播過(guò)程中,梯度總和不變,所以我們要?dú)w一化
也就是
因此最大元素那四個(gè)位置對(duì)應(yīng)的梯度是0.25 在平均池化過(guò)程中,操作有些許不一樣,具體可以參考 Pool反向傳播求導(dǎo)細(xì)節(jié)
靜態(tài)圖與動(dòng)態(tài)圖的區(qū)別
靜態(tài)圖
在tf1時(shí)代,其運(yùn)行機(jī)制是靜態(tài)圖,也就是符號(hào)式編程,tensorflow也是按照上面計(jì)算圖的思想,把整個(gè)運(yùn)算邏輯抽象成一張數(shù)據(jù)流圖

tensorflow提出了一個(gè)概念,叫PlaceHolder,即數(shù)據(jù)占位符。PlaceHolder只是有shape,dtype等基礎(chǔ)信息,沒(méi)有實(shí)際的數(shù)據(jù)。在網(wǎng)絡(luò)定義好后,需要對(duì)其進(jìn)行編譯。于是網(wǎng)絡(luò)就根據(jù)每一步驟的placeholder信息進(jìn)行編譯構(gòu)圖,構(gòu)圖過(guò)程中檢查是否有維度不匹配等錯(cuò)誤。待構(gòu)圖好后,再喂入數(shù)據(jù)給流圖。靜態(tài)圖只構(gòu)圖一次,運(yùn)行效率也會(huì)相對(duì)較高點(diǎn)。當(dāng)然現(xiàn)在的各大框架也在努力優(yōu)化動(dòng)態(tài)圖,縮小兩者之間效率差距。
動(dòng)態(tài)圖
動(dòng)態(tài)圖也稱為命令式編程,就像我們寫(xiě)代碼一樣,寫(xiě)到哪兒就執(zhí)行到哪兒。Pytorch便屬于這種,它與用戶更加友好,可以隨時(shí)在中間打印張量信息,方便我們進(jìn)行debug。
每一次讀取數(shù)據(jù)進(jìn)行計(jì)算,它都會(huì)重新進(jìn)行一次構(gòu)圖,并按照流程執(zhí)行下去。其特性更加適合研究者以及入門(mén)小白
兩者區(qū)別
靜態(tài)圖只構(gòu)圖一次 動(dòng)態(tài)圖每次運(yùn)行都重新構(gòu)圖 靜態(tài)圖能在編譯中做更好的優(yōu)化,但動(dòng)態(tài)圖的優(yōu)化也在不斷提升中

比如按動(dòng)態(tài)圖我們先乘后加,形式如左圖。在靜態(tài)圖里我們可以優(yōu)化到同一層級(jí),乘法和加法同時(shí)做到
總結(jié)
這篇文章講解了計(jì)算圖的提出,框架內(nèi)部常見(jiàn)算子的反向傳播方法,以及動(dòng)靜態(tài)圖的主要區(qū)別。限于篇幅,沒(méi)有講的特別深入,但讀完也基本可以對(duì)框架原理有了基本的了解~
推薦閱讀
深入理解圖注意力機(jī)制
深度學(xué)習(xí)中7種最優(yōu)化算法的可視化與理解
Gradient Centralization: 一行代碼加速訓(xùn)練并提升泛化能力 | ECCV 2020 Oral

