AI 框架基礎(chǔ)技術(shù)之自動求導(dǎo)機(jī)制 (Autograd)
點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師”
設(shè)為星標(biāo),干貨直達(dá)!
AI編輯:我是小將
本文作者:OpenMMLab @小P家的 900420
https://zhuanlan.zhihu.com/p/347385418
本文已由原作者授權(quán)轉(zhuǎn)載
0 前言
可以把神經(jīng)網(wǎng)絡(luò)看作一個復(fù)合數(shù)學(xué)函數(shù),網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計決定了多個基礎(chǔ)函數(shù)如何復(fù)合成復(fù)合函數(shù),網(wǎng)絡(luò)的訓(xùn)練過程確定了復(fù)合函數(shù)的所有參數(shù)。為了獲得一個“優(yōu)秀”的函數(shù),訓(xùn)練過程中會基于給定的數(shù)據(jù)集合,對該函數(shù)參數(shù)進(jìn)行多次迭代修正,重復(fù)如下幾個步驟:
前向傳播
計算損失
反向傳播(計算參數(shù)的梯度)
更新參數(shù)
這里第 3 步反向傳播過程會根據(jù)輸出的梯度推導(dǎo)出參數(shù)的梯度,第 4 步會根據(jù)這些梯度更新神經(jīng)網(wǎng)絡(luò)的參數(shù),這兩步是神經(jīng)網(wǎng)絡(luò)可以不斷優(yōu)化的核心。反向傳播過程中需要計算出所有參數(shù)的梯度,這當(dāng)然可以由網(wǎng)絡(luò)設(shè)計者自己計算并且通過硬編碼的方式實現(xiàn),但是網(wǎng)絡(luò)模型復(fù)雜多樣,為每個網(wǎng)絡(luò)都硬編碼去實現(xiàn)參數(shù)梯度計算將會耗費(fèi)大量精力。因此,AI 框架中往往會實現(xiàn)自動求導(dǎo)機(jī)制,以自動完成參數(shù)的梯度計算,并在每個 iter 中自動更新梯度,使得網(wǎng)絡(luò)設(shè)計者可以將注意力放到網(wǎng)絡(luò)結(jié)構(gòu)的設(shè)計中,而不必關(guān)心梯度是如何計算的。
本文的內(nèi)容基于我們自研的 AI 框架 SenseParrots,介紹框架自動求導(dǎo)的實現(xiàn)方式。本次分享將分為如下兩部分:
自動求導(dǎo)機(jī)制介紹
SenseParrots 自動求導(dǎo)實現(xiàn)
1 自動求導(dǎo)機(jī)制介紹
從數(shù)學(xué)層面上看求導(dǎo)這個問題,有很多種分類方法:按照求導(dǎo)結(jié)果來分,可以分為數(shù)值求導(dǎo)和符號求導(dǎo);按照求導(dǎo)順序來分,可以分為 forward mode 和 reverse mode;按照導(dǎo)數(shù)階數(shù)來分,可以分為一階導(dǎo)和高階導(dǎo)。在 AI 框架中實現(xiàn)自動求導(dǎo),最終目標(biāo)是拿到數(shù)值導(dǎo)數(shù),這里有兩種方式:第一種是直接進(jìn)行數(shù)值導(dǎo)數(shù)的計算;第二種是先求出符號導(dǎo)數(shù),再把數(shù)值帶入進(jìn)去。基于這個思路,目前主流 AI 框架中有兩種完全不同的自動求導(dǎo)機(jī)制:
基于對偶圖變換的自動求導(dǎo)機(jī)制
基于 reverse mode 的自動求導(dǎo)機(jī)制
1.1 基于對偶圖的自動求導(dǎo)機(jī)制
基于對偶圖的自動求導(dǎo)機(jī)制的實現(xiàn)思路是,首先通過一些模型解析手段獲得目標(biāo)函數(shù)對應(yīng)的前向計算圖,然后遍歷前向計算圖,使用計算圖中每一個前向算子節(jié)點(diǎn)對應(yīng)的反向算子節(jié)點(diǎn)構(gòu)造出反向計算圖,進(jìn)而實現(xiàn)自動求導(dǎo)。這里獲得的反向計算圖相當(dāng)于目標(biāo)函數(shù)符號導(dǎo)數(shù)結(jié)果,與原函數(shù)無差別的,可以將反向計算圖也用一個函數(shù)表示,傳入不同的參數(shù)進(jìn)行正常的調(diào)用。TVM 中基于對偶圖實現(xiàn)了一套自動求導(dǎo)機(jī)制,這里給出一段代碼示例:
s = (5, 10, 5)
t = relay.TensorType((5, 10, 5))
x = relay.var("x", t)
y = relay.var("y", t)
z = x + y
fwd_func = run_infer_type(relay.Function([x, y], z))
bwd_func = run_infer_type(gradient(fwd_func))
x_data = np.random.rand(*s).astype(t.dtype)
y_data = np.random.rand(*s).astype(t.dtype)
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad0, op_grad1) = intrp.evaluate(bwd_func)(x_data, y_data)1.2 基于 reverse mode 的自動求導(dǎo)機(jī)制
基于對偶圖的自動求導(dǎo)機(jī)制實現(xiàn)思路清晰,且有一些優(yōu)勢:1、只需要實現(xiàn)一次符號倒數(shù)的求解,后續(xù)只需要用不同的數(shù)值多次調(diào)用就可以得到目標(biāo)數(shù)值導(dǎo)數(shù);2、高階導(dǎo)的實現(xiàn)方式非常明顯,只需要在求導(dǎo)結(jié)果函數(shù)上進(jìn)一步調(diào)用自動求導(dǎo)模塊。但是該方案對計算圖和算子節(jié)點(diǎn)定義有比較嚴(yán)格的要求,前向算子節(jié)點(diǎn)和反向算子節(jié)點(diǎn)基本上要一一對應(yīng);另一方面,該方案需要先完成前向計算圖的完整解析,才能開始反向計算圖的生成,整個過程具有滯后性,所以適用于基于靜態(tài)圖的 AI 框架。在基于動態(tài)圖的 AI 框架,如 PyTorch、SenseParrots 中,我們一般使用基于 reverse mode 的自動求導(dǎo)機(jī)制。
這里對 reverse mode 概念進(jìn)行詳細(xì)介紹。reverse mode,即依據(jù)[鏈?zhǔn)椒▌t]的反向模式,指在進(jìn)行梯度計算過程中,從最后一個節(jié)點(diǎn)開始,依次向前計算得到每個輸入的梯度?;?reverse mode 進(jìn)行梯度計算,可以有效地把各個節(jié)點(diǎn)的梯度計算解耦開,每次只需要關(guān)注計算圖中當(dāng)前節(jié)點(diǎn)的梯度計算。
基于reverse mode進(jìn)行梯度計算的過程可以分為三步,以下列復(fù)合函數(shù)計算為例:


1. 首先創(chuàng)建計算圖:

2. 然后計算前向傳播的值,即?
?。
3. 在進(jìn)行反向傳播時,基于給定的輸出?
?的梯度?
?,依次計算:



在基于動態(tài)圖的 AI 框架中,計算圖的創(chuàng)建發(fā)生在前向傳播過程中,于是基于 reverse mode 的自動求導(dǎo)機(jī)制,整體過程可以簡化為兩步:第一步是在前向傳播過程中構(gòu)建出計算圖,與基于對偶圖的自動求導(dǎo)機(jī)制的滯后性相反,這里在前向傳播過程中就可以構(gòu)造出的反向計算圖;第二步是基于輸出的梯度信息對輸入自動求導(dǎo)。更多的細(xì)節(jié)將在下一章節(jié)展開。
2 SenseParrots 自動求導(dǎo)實現(xiàn)
2.1 自動求導(dǎo)機(jī)制組件
SenseParrots 是一個基于動態(tài)圖的AI框架(在線編譯功能部分進(jìn)行了局部靜態(tài)化,并不影響自動求導(dǎo)的整體機(jī)制),自動求導(dǎo)機(jī)制采用上述的反向模式,整個自動求導(dǎo)機(jī)制主要依賴于以下三個部分:
DArray: 計算數(shù)據(jù)的數(shù)據(jù)結(jié)構(gòu), 可以想象成多維數(shù)組, 其中包含參與運(yùn)算的數(shù)據(jù)、其梯度及以其作為輸出的 GradFn。
Function: 一個基本的運(yùn)算單元,包括一個操作的正向計算函數(shù)及其反向計算函數(shù),每個計算過程對應(yīng)一個 Function。比如一個 ReLU 激活函數(shù)的 Function 包括如下兩部分
Class ReLU : Function {
DArray forward(const DArray& x) {
DArray y = ...; // ReLU正向計算過程
return y;
}
DArray backward(const DArray& dy) {
DArray dx = ...; // ReLU反向計算過程
return dx;
}
};
GradFn: 計算圖中的節(jié)點(diǎn),每個 Function 在執(zhí)行正向計算的時候會產(chǎn)生一個 GradFn 對象,保存了輸入和輸出的梯度信息的指針、Function 指針以確定反向計算要調(diào)用的函數(shù)、后繼 GradFn 節(jié)點(diǎn)指針,該對象保存在該 Function 前向計算的輸出 DArray 中。
PS: SenseParrots 完全兼容 PyTorch,也為了方便大家理解,后文中涉及到的代碼采用 Torch 接口。
2.2 自動求導(dǎo)機(jī)制的控制選項
DArray 的 requires_grad 屬性標(biāo)志該數(shù)據(jù)是否需要求梯度。requires_grad 設(shè)置為 True 時計算梯度,并且會生成 LeafGradFn(GradFn 的子類)來標(biāo)識該節(jié)點(diǎn)為葉子節(jié)點(diǎn),計算圖的構(gòu)造依賴于輸入的 requires_grad 屬性;
框架是否開啟求導(dǎo)。默認(rèn)情況下框架是開啟求導(dǎo)的,也提供了顯示的開關(guān)求導(dǎo)的接口:torch.no_grad()、torch.enable_grad(),在框架關(guān)閉求導(dǎo)功能的情況下,不會構(gòu)造計算圖。
2.3 前向傳播過程中構(gòu)造計算圖
SenseParrots 在前向計算過程中,會根據(jù)用戶定義的計算過程,依次調(diào)用每個 Function 中的前向計算函數(shù)來完成計算。在調(diào)用每一個 Function 時,首先判斷輸入中是否有需要求梯度的:
如果輸入都不需要求梯度,則不會構(gòu)造計算圖,直接調(diào)用函數(shù)計算得到輸出, 并將輸出的 requires_grad 設(shè)置為 False;
如果輸入中有需要求梯度的,則調(diào)用函數(shù)計算得到輸出, 并將輸出的 requires_grad 設(shè)置為 True, 同時會相應(yīng)生成一個 GradFn 對象,并完成如下關(guān)聯(lián)工作(“保存”都是以 shared_ptr 方式):
將該 Function 記錄進(jìn)該 GradFn 對象,以表明在反向求導(dǎo)時,用 GradFn 中記錄的 Function 的反向計算函數(shù)來進(jìn)行梯度計算;
將該 Function 前向計算函數(shù)的輸入 DArray 的梯度記錄進(jìn) GradFn 對象,將該 Function 前向計算函數(shù)的輸出 DArray 的梯度記錄進(jìn) GradFn 對象;
將該 Function 前向計算函數(shù)的輸入 DArray 中所記錄的 GradFn 記錄為 GradFn 的后繼節(jié)點(diǎn);
將該 GradFn 保存進(jìn) Function 前向計算函數(shù)的所有輸出當(dāng)中。
由最初的輸入數(shù)據(jù)(葉子節(jié)點(diǎn))開始,依次執(zhí)行 Function,便可以構(gòu)造得到一張完整的計算圖。下面舉例子介紹計算圖的構(gòu)造過程(框架默認(rèn)啟用求導(dǎo)功能的情況下):
import torch
x1 = torch.randn((2,3,4), requires_grad=True)
x2 = torch.randn((2,3,4), requires_grad=True)
x3 = torch.randn((2,3,4))
x4 = torch.randn((2,3,4))
y1 = x1 + x2
y2 = x3 + x4
z = y1 * y2
z += x2首先我們計算的輸入數(shù)據(jù)為 x1、x2、x3、x4,當(dāng)前計算圖中 x1、x2 需要計算梯度,已經(jīng)創(chuàng)建 LeafGradFn 節(jié)點(diǎn),而 x3、x4的 GradFn 都為空指針,因此,最初的計算圖中包含兩個節(jié)點(diǎn),即 x1、x2 的 LeafGF1、LeafGF2。

以 x1、x2 作為輸入,調(diào)用 "+" Function 的正向計算函數(shù),得到輸出 y1,因為 x1、x2 都需要計算梯度,設(shè)置 y1 的 requires_grad=True,同時生成 GradFn,GF1, 將 "+" Function 記錄到 GF1 中,將輸入 x1、x2 的梯度記錄到 GF1 中,將輸出 y 的梯度記錄在 GF1 中,將 x1、x2 的 GradFn 記錄為 GradFn 的后繼節(jié)點(diǎn),將 GF1 保存在 y1 中;當(dāng)前計算圖中有 3 個節(jié)點(diǎn):LeafGF1、LeafGF2、GF1。

以 x3、x4 作為輸入,調(diào)用 "+" Function 的正向計算函數(shù),得到輸出 y2, 因為 x3、x4 都不需要計算梯度,y2 的 requires_grad=False, 此時計算圖中仍然只有 3 個節(jié)點(diǎn):LeafGF1、LeafGF2、GF1。

以 y1、y2 作為輸入,調(diào)用 "*" Funtcion 的正向計算函數(shù),得到輸出 z,由于輸入 y1 需要計算梯度,設(shè)置 z 的 requires_grad=True,同時生成 GradFn GF2,并且完成相應(yīng)信息的關(guān)聯(lián),當(dāng)前計算圖中有 4 個節(jié)點(diǎn):LeafGF1、LeafGF2、GF1、GF2。

需要注意的是,最后一個計算 "+=" 是一個 inplace 的計算,即以 z、x2 為輸入,計算結(jié)果 z,在處理 inplace 計算時,仍然遵循同樣的 GradFN 構(gòu)造方式即可,同時構(gòu)造 GF3,將 "+=" Function、輸入 x1 梯度、z 梯度、輸出 z 梯度、后繼節(jié)點(diǎn) GF2、LeafGF1 記錄進(jìn) GF3,需要注意的是,這里將 z 中的 GradFn 更新為 GF3,而原來z中保存的 GF2 作為 GF3 的后繼節(jié)點(diǎn)了,此時計算圖中有 5 個節(jié)點(diǎn):LeafGF1、LeafGF2、GF1、GF2、GF3。

由此得到了完整的計算圖,并且完成了相關(guān)信息的關(guān)聯(lián),完整的計算圖如下:

2.4 基于輸出的梯度信息對輸入自動求導(dǎo)
z.backward(torch.ones_like(z))在基于動態(tài)圖的 AI 框架中,反向求導(dǎo)過程通常是由上述的.backward(梯度)函數(shù)觸發(fā)的。SenseParrots 的反向求導(dǎo)過程,首先根據(jù)給定的輸出梯度,更新最終輸出的梯度值;然后對計算圖中節(jié)點(diǎn)進(jìn)行拓?fù)渑判?,獲得滿足依賴關(guān)系的 GradFn 的執(zhí)行順序;依次執(zhí)行 GradFn 中所記錄 Function 的反向計算函數(shù),根據(jù)輸出的梯度,計算并更新輸入的梯度。
首先看一下上述例子,其中 x1 只與一個 GradFn 相關(guān),其梯度只會被計算一次,這種輸入只影響單個輸出的情況,是反向求導(dǎo)中最簡單的一種情況;x2 與兩個 GradFn 相關(guān),這是反向求導(dǎo)中,一個輸入影響多個直接輸出的情況,需要注意,輸入 x2 的梯度也會被計算兩次,在梯度更新時,需要將多次計算得到的梯度進(jìn)行累加;z 的計算涉及到 inplace 操作,我們在 2.3 的第 5 步中說明了該情況的處理。下面介紹上述例子的反向求導(dǎo)過程:
基于給定的 z 的梯度信息,更新z中的梯度值;
基于計算圖進(jìn)行拓?fù)渑判?,獲得 GradFn 的執(zhí)行隊列(一個可能的序列為:GF3 -> GF2 -> GF1 -> LeafGF1 -> LeafGF2);
開始反向求導(dǎo),首先執(zhí)行 GF3,GF3 是一個 inplace 操作,以 z 的梯度作為輸入,調(diào)用 "+=" Function 的反向計算函數(shù),計算并更新 z、x2 的梯度,此時執(zhí)行隊列為(GF2 -> GF1 -> LeafGF1 -> LeafGF2);

4. 執(zhí)行 GF2,以 GF3 計算之后的 z 的梯度作為輸入,調(diào)用 "*" Function的反向計算函數(shù),計算 y1、y2 的梯度, 更新 y1 的梯度,因為 y2 不需要求梯度,所以其梯度信息舍棄, 此時執(zhí)行隊列為(GF1 -> LeafGF1 -> LeafGF2);

5. 執(zhí)行 GF1,以 y1 的梯度作為輸入,調(diào)用 "+" Function 的反向計算函數(shù),計算 x1、x2 的梯度,更新 x1 的梯度,而 x2 的梯度信息需要在之前計算結(jié)果的基礎(chǔ)上累加,此時執(zhí)行隊列為(LeafGF1 -> LeafGF2);

6. 依次執(zhí)行 LeafGF1、LeafGF2。

7. 執(zhí)行隊列為空,反向求導(dǎo)過程結(jié)束,默認(rèn)情況下計算圖會被清空,非葉子節(jié)點(diǎn)的梯度信息清空。由此得到了需要的計算梯度。

推薦閱讀
谷歌提出Meta Pseudo Labels,刷新ImageNet上的SOTA!
漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA
SWA:讓你的目標(biāo)檢測模型無痛漲點(diǎn)1% AP
CondInst:性能和速度均超越Mask RCNN的實例分割模型
mmdetection最小復(fù)刻版(十一):概率Anchor分配機(jī)制PAA深入分析
MMDetection新版本V2.7發(fā)布,支持DETR,還有YOLOV4在路上!
無需tricks,知識蒸餾提升ResNet50在ImageNet上準(zhǔn)確度至80%+
不妨試試MoCo,來替換ImageNet上pretrain模型!
mmdetection最小復(fù)刻版(七):anchor-base和anchor-free差異分析
mmdetection最小復(fù)刻版(四):獨(dú)家yolo轉(zhuǎn)化內(nèi)幕
機(jī)器學(xué)習(xí)算法工程師
? ??? ? ? ? ? ? ? ? ? ? ? ??????? ??一個用心的公眾號
?

