AI 框架基礎(chǔ)技術(shù)之自動(dòng)求導(dǎo)機(jī)制 (Autograd)
0 前言
可以把神經(jīng)網(wǎng)絡(luò)看作一個(gè)復(fù)合數(shù)學(xué)函數(shù),網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計(jì)決定了多個(gè)基礎(chǔ)函數(shù)如何復(fù)合成復(fù)合函數(shù),網(wǎng)絡(luò)的訓(xùn)練過(guò)程確定了復(fù)合函數(shù)的所有參數(shù)。為了獲得一個(gè)“優(yōu)秀”的函數(shù),訓(xùn)練過(guò)程中會(huì)基于給定的數(shù)據(jù)集合,對(duì)該函數(shù)參數(shù)進(jìn)行多次迭代修正,重復(fù)如下幾個(gè)步驟:
前向傳播 計(jì)算損失 反向傳播(計(jì)算參數(shù)的梯度) 更新參數(shù)
這里第 3 步反向傳播過(guò)程會(huì)根據(jù)輸出的梯度推導(dǎo)出參數(shù)的梯度,第 4 步會(huì)根據(jù)這些梯度更新神經(jīng)網(wǎng)絡(luò)的參數(shù),這兩步是神經(jīng)網(wǎng)絡(luò)可以不斷優(yōu)化的核心。反向傳播過(guò)程中需要計(jì)算出所有參數(shù)的梯度,這當(dāng)然可以由網(wǎng)絡(luò)設(shè)計(jì)者自己計(jì)算并且通過(guò)硬編碼的方式實(shí)現(xiàn),但是網(wǎng)絡(luò)模型復(fù)雜多樣,為每個(gè)網(wǎng)絡(luò)都硬編碼去實(shí)現(xiàn)參數(shù)梯度計(jì)算將會(huì)耗費(fèi)大量精力。因此,AI 框架中往往會(huì)實(shí)現(xiàn)自動(dòng)求導(dǎo)機(jī)制,以自動(dòng)完成參數(shù)的梯度計(jì)算,并在每個(gè) iter 中自動(dòng)更新梯度,使得網(wǎng)絡(luò)設(shè)計(jì)者可以將注意力放到網(wǎng)絡(luò)結(jié)構(gòu)的設(shè)計(jì)中,而不必關(guān)心梯度是如何計(jì)算的。
本文的內(nèi)容基于我們自研的 AI 框架 SenseParrots,介紹框架自動(dòng)求導(dǎo)的實(shí)現(xiàn)方式。本次分享將分為如下兩部分:
自動(dòng)求導(dǎo)機(jī)制介紹 SenseParrots 自動(dòng)求導(dǎo)實(shí)現(xiàn)
1 自動(dòng)求導(dǎo)機(jī)制介紹
從數(shù)學(xué)層面上看求導(dǎo)這個(gè)問(wèn)題,有很多種分類方法:按照求導(dǎo)結(jié)果來(lái)分,可以分為數(shù)值求導(dǎo)和符號(hào)求導(dǎo);按照求導(dǎo)順序來(lái)分,可以分為 forward mode 和 reverse mode;按照導(dǎo)數(shù)階數(shù)來(lái)分,可以分為一階導(dǎo)和高階導(dǎo)。在 AI 框架中實(shí)現(xiàn)自動(dòng)求導(dǎo),最終目標(biāo)是拿到數(shù)值導(dǎo)數(shù),這里有兩種方式:第一種是直接進(jìn)行數(shù)值導(dǎo)數(shù)的計(jì)算;第二種是先求出符號(hào)導(dǎo)數(shù),再把數(shù)值帶入進(jìn)去?;谶@個(gè)思路,目前主流 AI 框架中有兩種完全不同的自動(dòng)求導(dǎo)機(jī)制:
基于對(duì)偶圖變換的自動(dòng)求導(dǎo)機(jī)制 基于 reverse mode 的自動(dòng)求導(dǎo)機(jī)制
1.1 基于對(duì)偶圖的自動(dòng)求導(dǎo)機(jī)制
基于對(duì)偶圖的自動(dòng)求導(dǎo)機(jī)制的實(shí)現(xiàn)思路是,首先通過(guò)一些模型解析手段獲得目標(biāo)函數(shù)對(duì)應(yīng)的前向計(jì)算圖,然后遍歷前向計(jì)算圖,使用計(jì)算圖中每一個(gè)前向算子節(jié)點(diǎn)對(duì)應(yīng)的反向算子節(jié)點(diǎn)構(gòu)造出反向計(jì)算圖,進(jìn)而實(shí)現(xiàn)自動(dòng)求導(dǎo)。這里獲得的反向計(jì)算圖相當(dāng)于目標(biāo)函數(shù)符號(hào)導(dǎo)數(shù)結(jié)果,與原函數(shù)無(wú)差別的,可以將反向計(jì)算圖也用一個(gè)函數(shù)表示,傳入不同的參數(shù)進(jìn)行正常的調(diào)用。TVM 中基于對(duì)偶圖實(shí)現(xiàn)了一套自動(dòng)求導(dǎo)機(jī)制,這里給出一段代碼示例:
s?=?(5,?10,?5)
t?=?relay.TensorType((5,?10,?5))
x?=?relay.var("x",?t)
y?=?relay.var("y",?t)
z?=?x?+?y
??
fwd_func?=?run_infer_type(relay.Function([x,?y],?z))
bwd_func?=?run_infer_type(gradient(fwd_func))
????
x_data?=?np.random.rand(*s).astype(t.dtype)
y_data?=?np.random.rand(*s).astype(t.dtype)
intrp?=?relay.create_executor(ctx=ctx,?target=target)
op_res,?(op_grad0,?op_grad1)?=?intrp.evaluate(bwd_func)(x_data,?y_data)
1.2 基于 reverse mode 的自動(dòng)求導(dǎo)機(jī)制
基于對(duì)偶圖的自動(dòng)求導(dǎo)機(jī)制實(shí)現(xiàn)思路清晰,且有一些優(yōu)勢(shì):1、只需要實(shí)現(xiàn)一次符號(hào)倒數(shù)的求解,后續(xù)只需要用不同的數(shù)值多次調(diào)用就可以得到目標(biāo)數(shù)值導(dǎo)數(shù);2、高階導(dǎo)的實(shí)現(xiàn)方式非常明顯,只需要在求導(dǎo)結(jié)果函數(shù)上進(jìn)一步調(diào)用自動(dòng)求導(dǎo)模塊。但是該方案對(duì)計(jì)算圖和算子節(jié)點(diǎn)定義有比較嚴(yán)格的要求,前向算子節(jié)點(diǎn)和反向算子節(jié)點(diǎn)基本上要一一對(duì)應(yīng);另一方面,該方案需要先完成前向計(jì)算圖的完整解析,才能開始反向計(jì)算圖的生成,整個(gè)過(guò)程具有滯后性,所以適用于基于靜態(tài)圖的 AI 框架。在基于動(dòng)態(tài)圖的 AI 框架,如 PyTorch、SenseParrots 中,我們一般使用基于 reverse mode 的自動(dòng)求導(dǎo)機(jī)制。
這里對(duì) reverse mode 概念進(jìn)行詳細(xì)介紹。reverse mode,即依據(jù)[鏈?zhǔn)椒▌t]的反向模式,指在進(jìn)行梯度計(jì)算過(guò)程中,從最后一個(gè)節(jié)點(diǎn)開始,依次向前計(jì)算得到每個(gè)輸入的梯度?;?reverse mode 進(jìn)行梯度計(jì)算,可以有效地把各個(gè)節(jié)點(diǎn)的梯度計(jì)算解耦開,每次只需要關(guān)注計(jì)算圖中當(dāng)前節(jié)點(diǎn)的梯度計(jì)算。
基于reverse mode進(jìn)行梯度計(jì)算的過(guò)程可以分為三步,以下列復(fù)合函數(shù)計(jì)算為例:


首先創(chuàng)建計(jì)算圖:

然后計(jì)算前向傳播的值,即

在進(jìn)行反向傳播時(shí),基于給定的輸出z的梯度dz,依次計(jì)算:




在基于動(dòng)態(tài)圖的 AI 框架中,計(jì)算圖的創(chuàng)建發(fā)生在前向傳播過(guò)程中,于是基于 reverse mode 的自動(dòng)求導(dǎo)機(jī)制,整體過(guò)程可以簡(jiǎn)化為兩步:第一步是在前向傳播過(guò)程中構(gòu)建出計(jì)算圖,與基于對(duì)偶圖的自動(dòng)求導(dǎo)機(jī)制的滯后性相反,這里在前向傳播過(guò)程中就可以構(gòu)造出的反向計(jì)算圖;第二步是基于輸出的梯度信息對(duì)輸入自動(dòng)求導(dǎo)。更多的細(xì)節(jié)將在下一章節(jié)展開。
2 SenseParrots 自動(dòng)求導(dǎo)實(shí)現(xiàn)
2.1 自動(dòng)求導(dǎo)機(jī)制組件
SenseParrots 是一個(gè)基于動(dòng)態(tài)圖的AI框架(在線編譯功能部分進(jìn)行了局部靜態(tài)化,并不影響自動(dòng)求導(dǎo)的整體機(jī)制),自動(dòng)求導(dǎo)機(jī)制采用上述的反向模式,整個(gè)自動(dòng)求導(dǎo)機(jī)制主要依賴于以下三個(gè)部分:
DArray: 計(jì)算數(shù)據(jù)的數(shù)據(jù)結(jié)構(gòu), 可以想象成多維數(shù)組, 其中包含參與運(yùn)算的數(shù)據(jù)、其梯度及以其作為輸出的 GradFn。 Function: 一個(gè)基本的運(yùn)算單元,包括一個(gè)操作的正向計(jì)算函數(shù)及其反向計(jì)算函數(shù),每個(gè)計(jì)算過(guò)程對(duì)應(yīng)一個(gè) Function。比如一個(gè) ReLU 激活函數(shù)的 Function 包括如下兩部分
????Class?ReLU?:?Function?{?????
????????DArray?forward(const?DArray&?x)?{?
????????????DArray?y?=?...;?//?ReLU正向計(jì)算過(guò)程?????????
????????????return?y;??????
????????}
?????????DArray?backward(const?DArray&?dy)?{
?????????????DArray?dx?=?...;?//?ReLU反向計(jì)算過(guò)程
?????????????return?dx;
?????????}
?????};?
GradFn: 計(jì)算圖中的節(jié)點(diǎn),每個(gè) Function 在執(zhí)行正向計(jì)算的時(shí)候會(huì)產(chǎn)生一個(gè) GradFn 對(duì)象,保存了輸入和輸出的梯度信息的指針、Function 指針以確定反向計(jì)算要調(diào)用的函數(shù)、后繼 GradFn 節(jié)點(diǎn)指針,該對(duì)象保存在該 Function 前向計(jì)算的輸出 DArray 中。
PS: SenseParrots 完全兼容 PyTorch,也為了方便大家理解,后文中涉及到的代碼采用 Torch 接口。
2.2 自動(dòng)求導(dǎo)機(jī)制的控制選項(xiàng)
DArray 的 requires_grad 屬性標(biāo)志該數(shù)據(jù)是否需要求梯度。requires_grad 設(shè)置為 True 時(shí)計(jì)算梯度,并且會(huì)生成 LeafGradFn(GradFn 的子類)來(lái)標(biāo)識(shí)該節(jié)點(diǎn)為葉子節(jié)點(diǎn),計(jì)算圖的構(gòu)造依賴于輸入的 requires_grad 屬性; 框架是否開啟求導(dǎo)。默認(rèn)情況下框架是開啟求導(dǎo)的,也提供了顯示的開關(guān)求導(dǎo)的接口:torch.no_grad()、torch.enable_grad(),在框架關(guān)閉求導(dǎo)功能的情況下,不會(huì)構(gòu)造計(jì)算圖。
2.3 前向傳播過(guò)程中構(gòu)造計(jì)算圖
SenseParrots 在前向計(jì)算過(guò)程中,會(huì)根據(jù)用戶定義的計(jì)算過(guò)程,依次調(diào)用每個(gè) Function 中的前向計(jì)算函數(shù)來(lái)完成計(jì)算。在調(diào)用每一個(gè) Function 時(shí),首先判斷輸入中是否有需要求梯度的:
如果輸入都不需要求梯度,則不會(huì)構(gòu)造計(jì)算圖,直接調(diào)用函數(shù)計(jì)算得到輸出, 并將輸出的 requires_grad 設(shè)置為 False;
如果輸入中有需要求梯度的,則調(diào)用函數(shù)計(jì)算得到輸出, 并將輸出的 requires_grad 設(shè)置為 True, 同時(shí)會(huì)相應(yīng)生成一個(gè) GradFn 對(duì)象,并完成如下關(guān)聯(lián)工作(“保存”都是以 shared_ptr 方式):
將該 Function 記錄進(jìn)該 GradFn 對(duì)象,以表明在反向求導(dǎo)時(shí),用 GradFn 中記錄的 Function 的反向計(jì)算函數(shù)來(lái)進(jìn)行梯度計(jì)算; 將該 Function 前向計(jì)算函數(shù)的輸入 DArray 的梯度記錄進(jìn) GradFn 對(duì)象,將該 Function 前向計(jì)算函數(shù)的輸出 DArray 的梯度記錄進(jìn) GradFn 對(duì)象; 將該 Function 前向計(jì)算函數(shù)的輸入 DArray 中所記錄的 GradFn 記錄為 GradFn 的后繼節(jié)點(diǎn); 將該 GradFn 保存進(jìn) Function 前向計(jì)算函數(shù)的所有輸出當(dāng)中。
由最初的輸入數(shù)據(jù)(葉子節(jié)點(diǎn))開始,依次執(zhí)行 Function,便可以構(gòu)造得到一張完整的計(jì)算圖。下面舉例子介紹計(jì)算圖的構(gòu)造過(guò)程(框架默認(rèn)啟用求導(dǎo)功能的情況下):
import?torch
x1?=?torch.randn((2,3,4),?requires_grad=True)
x2?=?torch.randn((2,3,4),?requires_grad=True)
x3?=?torch.randn((2,3,4))
x4?=?torch.randn((2,3,4))
????
y1?=?x1?+?x2
y2?=?x3?+?x4
z?=?y1?*?y2
z?+=?x2
首先我們計(jì)算的輸入數(shù)據(jù)為 x1、x2、x3、x4,當(dāng)前計(jì)算圖中 x1、x2 需要計(jì)算梯度,已經(jīng)創(chuàng)建 LeafGradFn 節(jié)點(diǎn),而 x3、x4的 GradFn 都為空指針,因此,最初的計(jì)算圖中包含兩個(gè)節(jié)點(diǎn),即 x1、x2 的 LeafGF1、LeafGF2。

以 x1、x2 作為輸入,調(diào)用 "+" Function 的正向計(jì)算函數(shù),得到輸出 y1,因?yàn)?x1、x2 都需要計(jì)算梯度,設(shè)置 y1 的 requires_grad=True,同時(shí)生成 GradFn,GF1, 將 "+" Function 記錄到 GF1 中,將輸入 x1、x2 的梯度記錄到 GF1 中,將輸出 y 的梯度記錄在 GF1 中,將 x1、x2 的 GradFn 記錄為 GradFn 的后繼節(jié)點(diǎn),將 GF1 保存在 y1 中;當(dāng)前計(jì)算圖中有 3 個(gè)節(jié)點(diǎn):LeafGF1、LeafGF2、GF1。

以 x3、x4 作為輸入,調(diào)用 "+" Function 的正向計(jì)算函數(shù),得到輸出 y2, 因?yàn)?x3、x4 都不需要計(jì)算梯度,y2 的 requires_grad=False, 此時(shí)計(jì)算圖中仍然只有 3 個(gè)節(jié)點(diǎn):LeafGF1、LeafGF2、GF1。

以 y1、y2 作為輸入,調(diào)用 "*" Funtcion 的正向計(jì)算函數(shù),得到輸出 z,由于輸入 y1 需要計(jì)算梯度,設(shè)置 z 的 requires_grad=True,同時(shí)生成 GradFn GF2,并且完成相應(yīng)信息的關(guān)聯(lián),當(dāng)前計(jì)算圖中有 4 個(gè)節(jié)點(diǎn):LeafGF1、LeafGF2、GF1、GF2。

需要注意的是,最后一個(gè)計(jì)算 "+=" 是一個(gè) inplace 的計(jì)算,即以 z、x2 為輸入,計(jì)算結(jié)果 z,在處理 inplace 計(jì)算時(shí),仍然遵循同樣的 GradFN 構(gòu)造方式即可,同時(shí)構(gòu)造 GF3,將 "+=" Function、輸入 x1 梯度、z 梯度、輸出 z 梯度、后繼節(jié)點(diǎn) GF2、LeafGF1 記錄進(jìn) GF3,需要注意的是,這里將 z 中的 GradFn 更新為 GF3,而原來(lái)z中保存的 GF2 作為 GF3 的后繼節(jié)點(diǎn)了,此時(shí)計(jì)算圖中有 5 個(gè)節(jié)點(diǎn):LeafGF1、LeafGF2、GF1、GF2、GF3。

由此得到了完整的計(jì)算圖,并且完成了相關(guān)信息的關(guān)聯(lián),完整的計(jì)算圖如下:

2.4 基于輸出的梯度信息對(duì)輸入自動(dòng)求導(dǎo)
????z.backward(torch.ones_like(z))
在基于動(dòng)態(tài)圖的 AI 框架中,反向求導(dǎo)過(guò)程通常是由上述的.backward(梯度)函數(shù)觸發(fā)的。SenseParrots 的反向求導(dǎo)過(guò)程,首先根據(jù)給定的輸出梯度,更新最終輸出的梯度值;然后對(duì)計(jì)算圖中節(jié)點(diǎn)進(jìn)行拓?fù)渑判颍@得滿足依賴關(guān)系的 GradFn 的執(zhí)行順序;依次執(zhí)行 GradFn 中所記錄 Function 的反向計(jì)算函數(shù),根據(jù)輸出的梯度,計(jì)算并更新輸入的梯度。
首先看一下上述例子,其中 x1 只與一個(gè) GradFn 相關(guān),其梯度只會(huì)被計(jì)算一次,這種輸入只影響單個(gè)輸出的情況,是反向求導(dǎo)中最簡(jiǎn)單的一種情況;x2 與兩個(gè) GradFn 相關(guān),這是反向求導(dǎo)中,一個(gè)輸入影響多個(gè)直接輸出的情況,需要注意,輸入 x2 的梯度也會(huì)被計(jì)算兩次,在梯度更新時(shí),需要將多次計(jì)算得到的梯度進(jìn)行累加;z 的計(jì)算涉及到 inplace 操作,我們?cè)?2.3 的第 5 步中說(shuō)明了該情況的處理。下面介紹上述例子的反向求導(dǎo)過(guò)程:
基于給定的 z 的梯度信息,更新z中的梯度值; 基于計(jì)算圖進(jìn)行拓?fù)渑判?,獲得 GradFn 的執(zhí)行隊(duì)列(一個(gè)可能的序列為:GF3 -> GF2 -> GF1 -> LeafGF1 -> LeafGF2); 開始反向求導(dǎo),首先執(zhí)行 GF3,GF3 是一個(gè) inplace 操作,以 z 的梯度作為輸入,調(diào)用 "+=" Function 的反向計(jì)算函數(shù),計(jì)算并更新 z、x2 的梯度,此時(shí)執(zhí)行隊(duì)列為(GF2 -> GF1 -> LeafGF1 -> LeafGF2);

執(zhí)行 GF2,以 GF3 計(jì)算之后的 z 的梯度作為輸入,調(diào)用 "*" Function的反向計(jì)算函數(shù),計(jì)算 y1、y2 的梯度, 更新 y1 的梯度,因?yàn)?y2 不需要求梯度,所以其梯度信息舍棄, 此時(shí)執(zhí)行隊(duì)列為(GF1 -> LeafGF1 -> LeafGF2);

執(zhí)行 GF1,以 y1 的梯度作為輸入,調(diào)用 "+" Function 的反向計(jì)算函數(shù),計(jì)算 x1、x2 的梯度,更新 x1 的梯度,而 x2 的梯度信息需要在之前計(jì)算結(jié)果的基礎(chǔ)上累加,此時(shí)執(zhí)行隊(duì)列為(LeafGF1 -> LeafGF2);

依次執(zhí)行 LeafGF1、LeafGF2。

執(zhí)行隊(duì)列為空,反向求導(dǎo)過(guò)程結(jié)束,默認(rèn)情況下計(jì)算圖會(huì)被清空,非葉子節(jié)點(diǎn)的梯度信息清空。由此得到了需要的計(jì)算梯度。
- The End -
長(zhǎng)按二維碼關(guān)注我們
本公眾號(hào)專注:
1. 技術(shù)分享;
2.?學(xué)術(shù)交流;
3.?資料共享。
歡迎關(guān)注我們,一起成長(zhǎng)!
