從零實(shí)現(xiàn)深度學(xué)習(xí)框架(七)優(yōu)化反向傳播相關(guān)代碼

引言
本著“凡我不能創(chuàng)造的,我就不能理解”的思想,本系列文章會(huì)基于純Python以及NumPy從零創(chuàng)建自己的深度學(xué)習(xí)框架,該框架類似PyTorch能實(shí)現(xiàn)自動(dòng)求導(dǎo)。
要深入理解深度學(xué)習(xí),從零開(kāi)始創(chuàng)建的經(jīng)驗(yàn)非常重要,從自己可以理解的角度出發(fā),盡量不適用外部完備的框架前提下,實(shí)現(xiàn)我們想要的模型。本系列文章的宗旨就是通過(guò)這樣的過(guò)程,讓大家切實(shí)掌握深度學(xué)習(xí)底層實(shí)現(xiàn),而不是僅做一個(gè)調(diào)包俠。
在前面的文章中,我們實(shí)現(xiàn)了反向傳播的模式。并實(shí)現(xiàn)了加法和乘法的計(jì)算圖。但是這種實(shí)現(xiàn)方式有一些弊端,本文就來(lái)優(yōu)化實(shí)現(xiàn)反向傳播模式的代碼,同時(shí)修復(fù)加法和乘法計(jì)算圖實(shí)現(xiàn)的問(wèn)題。
優(yōu)化反向傳播代碼
在上篇文章中,我們將_Function與Tensor放到了同一個(gè)文件中,這不符合單一職責(zé)模式。同時(shí)在實(shí)現(xiàn)Tensor的加法和乘法時(shí),我們需要手動(dòng)添加很多代碼,這也不優(yōu)雅。
def?__add__(self,?other):
????ctx?=?Add(self,?ensure_tensor(other))
????return?ctx.apply(ctx,?self,?ensure_tensor(other))
def?__mul__(self,?other):
????ctx?=?Mul(self,?ensure_tensor(other))
????return?ctx.apply(ctx,?self,?ensure_tensor(other))
首先,我們把與_Function相關(guān)的代碼移動(dòng)到新文件ops.py中:
from?typing?import?Any
import?numpy?as?np
from?core.tensor?import?Tensor
'''
ops.py保存所有運(yùn)算操作相關(guān)的類
'''
class?_Function:
????def?__init__(self,?*tensors:?"Tensor")?->?None:
????????#?該操作所依賴的所有輸入
????????self.depends_on?=?[t?for?t?in?tensors]
????????#?保存需要在backward()中使用的Tensor或其他對(duì)象(如Shape)
????????self.saved_tensors?=?[]
????def?__new__(cls,?*args,?**kwargs):
????????'''__new__是靜態(tài)方法,當(dāng)該類被實(shí)例化時(shí)調(diào)用'''
????????#?把以下方法轉(zhuǎn)換為靜態(tài)方法,我們可以通過(guò)類名直接調(diào)用
????????cls.forward?=?staticmethod(cls.forward)
????????cls.backward?=?staticmethod(cls.backward)
????????cls.apply?=?staticmethod(cls.apply)?#?新增
????????return?super().__new__(cls)
????def?save_for_backward(ctx,?*x:?Any)?->?None:
????????ctx.saved_tensors.extend(x)
????def?forward(ctx,?*args:?Any,?**kwargs:?Any)?->?np.ndarray:
????????'''前向傳播,進(jìn)行真正運(yùn)算的地方'''
????????raise?NotImplementedError("You?must?implement?the?forward?function?for?custom?Function.")
????def?backward(ctx,?grad:?Any)?->?Any:
????????'''實(shí)現(xiàn)反向傳播,計(jì)算梯度'''
????????raise?NotImplementedError("You?must?implement?the?backward?method?for?your?custom?Function?"
??????????????????????????????????"to?use?it?with?backward?mode?AD.")
????def?apply(fxn,?*xs:?"Tensor",?**kwargs)?->?"Tensor":
????????'''與PyTorch一樣,我們也不直接調(diào)用forward,而是調(diào)用此方法'''
????????#?先調(diào)用構(gòu)造函數(shù),傳入運(yùn)算依賴的Tensor
????????ctx?=?fxn(*xs)??#?調(diào)用到了_Function的__init__方法
????????#?[t.data for t in xs]遍歷Tensor中的data(np.ndarray)值,參與實(shí)際計(jì)算的都是NumPy的數(shù)組。
????????ret?=?Tensor(ctx.forward(ctx,?*[t.data?for?t?in?xs],?**kwargs),
?????????????????????requires_grad=any([t.requires_grad?for?t?in?xs]))
????????if?ret.requires_grad:
????????????ret._ctx?=?ctx
????????return?ret
class?Add(_Function):
????def?forward(ctx,?x:?np.ndarray,?y:?np.ndarray)?->?np.ndarray:
????????'''
????????實(shí)現(xiàn)?z?=?x?+?y?,我們這里的x和y都是Numpy數(shù)組,因此可能發(fā)生廣播,
????????在實(shí)現(xiàn)反向傳播是需要注意
????????'''
????????#?我們只要保存輸入各自的形狀即可
????????ctx.save_for_backward(x.shape,?y.shape)
????????#?進(jìn)行真正的運(yùn)算
????????return?x?+?y
????def?backward(ctx,?grad:?Any)?->?Any:
????????#?輸入有兩個(gè),都是需要計(jì)算梯度的,因此輸出也是兩個(gè)
????????return?grad,?grad
class?Mul(_Function):
????def?forward(ctx,?x:?np.ndarray,?y:?np.ndarray)?->?np.ndarray:
????????'''
????????實(shí)現(xiàn)?z?=?x?*?y
????????'''
????????#?乘法需要保存輸入x和y,用于反向傳播
????????ctx.save_for_backward(x,?y)
????????return?x?*?y
????def?backward(ctx,?grad:?Any)?->?Any:
????????x,?y?=?ctx.saved_tensors
????????#?分別返回?L/?x?和??L/?y
????????return?grad?*?y,?grad?*?x
同時(shí)修改apply方法為:
????def?apply(fxn,?*xs:?"Tensor",?**kwargs)?->?"Tensor":
????????'''與PyTorch一樣,我們也不直接調(diào)用forward,而是調(diào)用此方法'''
????????#?先調(diào)用構(gòu)造函數(shù),傳入運(yùn)算依賴的Tensor
????????ctx?=?fxn(*xs)??#?調(diào)用到了_Function的__init__方法
????????#?[t.data for t in xs]遍歷Tensor中的data(np.ndarray)值,參與實(shí)際計(jì)算的都是NumPy的數(shù)組。
????????ret?=?Tensor(ctx.forward(ctx,?*[t.data?for?t?in?xs],?**kwargs),
?????????????????????requires_grad=any([t.requires_grad?for?t?in?xs]))
????????if?ret.requires_grad:
????????????ret._ctx?=?ctx
????????return?ret
將該方法改為靜態(tài)方法,同時(shí)增加了ctx = fxn(*xs)這一句,在該方法實(shí)例化Function對(duì)象,傳入該運(yùn)算所依賴的輸入。
為了避免我們手動(dòng)添加__add__、__mul_這些實(shí)現(xiàn)。我們利用inspect類去自動(dòng)注冊(cè)相應(yīng)的魔法方法。
def?register(name,?fxn):
????print(f"register?{name}?:?{fxn}")
????def?dispatch(*xs,?**kwargs):
????????#?把所有的輸入都轉(zhuǎn)換為Tensor
????????xs?=?[ensure_tensor(x)?for?x?in?xs]
????????#?調(diào)用apply方法
????????return?fxn.apply(fxn,?*xs,?**kwargs)
????#?為Tensor添加屬性,名為name,值為dispatch函數(shù)引用
????setattr(Tensor,?name,?dispatch)
????#?這幾個(gè)方法都有__xx__,?__ixx__,?__rxx__?魔法方法
????if?name?in?["add",?"sub",?"mul",?"matmul"]:
????????setattr(Tensor,?f"__{name}__",?dispatch)
????????setattr(
????????????Tensor,?f"__i{name}__",?lambda?self,?x:?self.assign(dispatch(self,?x))
????????)??#?__i*__?代表原地操作
????????setattr(
????????????Tensor,?f"__r{name}__",?lambda?self,?x:?dispatch(x,?self)
????????)??#?__r*__?代表?other在操作符前,?self在操作符后
def?_register_ops(namespace):
????for?name,?cls?in?inspect.getmembers(namespace,?inspect.isclass):
????????if?name[0]?!=?"_"?and?name?!=?'Tensor':
????????????#?注冊(cè)所有_Function的子類
????????????register(name.lower(),?cls)
try:
????_register_ops(importlib.import_module("core.ops"))
except?ImportError?as?e:
????print(e)
此時(shí)當(dāng)我們初始化Tensor的時(shí)候,它會(huì)打印:
register?add?:?'core.ops.Add'>
register?mul?:?'core.ops.Mul'>
比如對(duì)于add,這段代碼會(huì)把__add__、__iadd__、__radd__和add綁定到其內(nèi)部的dispatch方法。
該方法主要做了兩件事,第一,統(tǒng)一把所有的輸入轉(zhuǎn)換為Tensor;第二,調(diào)用apply靜態(tài)方法。
優(yōu)化完了之后,我們得試一下還能正常使用么。
但是,這次博主不想寫一個(gè)main方法了,而是寫一些測(cè)試用例。并且,以后所有的代碼提交都走PR,利用github的action機(jī)制,只有測(cè)試通過(guò)的PR,才能合入主分支。
編寫測(cè)試用例
用一種比較簡(jiǎn)單的方法,就是創(chuàng)建以test開(kāi)頭的文件,同時(shí)里面的函數(shù)也是以test開(kāi)頭,idea會(huì)自動(dòng)識(shí)別為測(cè)試用例,如下圖所示:

我們分別測(cè)試標(biāo)量的加法、同shape向量的加法以及廣播情況下向量的加法。
from?core.tensor?import?Tensor
import?numpy?as?np
def?test_simple_add():
????x?=?Tensor(1,?requires_grad=True)
????y?=?2
????z?=?x?+?y
????z.backward()
????assert?x.grad.data?==?1.0
def?test_array_add():
????x?=?Tensor([1,?2,?3],?requires_grad=True)
????y?=?Tensor([4,?5,?6],?requires_grad=True)
????z?=?x?+?y
????assert?z.data.tolist()?==?[5.,?7.,?9.]
????#?如果
????z.backward([1,?1,?1])
????assert?x.grad.data.tolist()?==?[1,?1,?1]
????assert?y.grad.data.tolist()?==?[1,?1,?1]
????x?+=?1
????assert?x.grad?is?None
????assert?x.data.tolist()?==?[2,?3,?4]
def?test_broadcast_add():
????"""
????測(cè)試當(dāng)發(fā)生廣播時(shí),我們的代碼還能表現(xiàn)正常嗎。
????對(duì)于?z?=?x?+?y
????如果x.shape == y.shape,那么就像上面的例子一樣,沒(méi)什么問(wèn)題。
????如果x.shape?==?(2,3)??y.shape?==?(3,)?那么,根據(jù)廣播,先會(huì)在y左邊插入一個(gè)維度1,變成?->?y.shape?==?(1,3)
????????接著,在第0個(gè)維度上進(jìn)行復(fù)制,使得新的維度?y.shape?==?(2,3)
????這樣的話,對(duì)x求梯度時(shí),梯度要和x的shape保持一致;對(duì)y求梯度時(shí),也要和y的shape保持一致。
????"""
????x?=?Tensor(np.random.randn(2,?3),?requires_grad=True)??#?(2,3)
????y?=?Tensor(np.random.randn(3),?requires_grad=True)??#?(3,)
????z?=?x?+?y??#?(2,3)
????z.backward(Tensor(np.ones_like(x.data)))??#?grad.shape?==?z.shape
????assert?x.grad.data.tolist()?==?np.ones_like(x.data).tolist()
????assert?y.grad.data.tolist?==?[2,?2]
分別執(zhí)行每一個(gè)測(cè)試用例,第一個(gè)沒(méi)有問(wèn)題:
test_add.py::test_simple_add?PASSED??????????????????????????????????????[100%]
第二個(gè)測(cè)試方法報(bào)錯(cuò)了:
>???????????grads?=?t._ctx.backward(t._ctx,?t.grad.data)
E???????????AttributeError:?'list'?object?has?no?attribute?'data'
../../core/tensor.py:177:?AttributeError
==============================?1?failed?in?0.38s?===============================
哦,我們要確保backward()方法傳入的grad為Tensor對(duì)象。
所以,我們修改下對(duì)應(yīng)的backward()代碼:
????????#?如果傳遞過(guò)來(lái)的grad為空
????????if?grad?is?None:
????????????if?self.shape?==?():
????????????????#?設(shè)置梯度值為1,grad本身不需要計(jì)算梯度
????????????????self._grad?=?Tensor(1)
????????????else:
????????????????#?如果當(dāng)前Tensor得到不是標(biāo)量,那么grad必須制定
????????????????raise?RuntimeError("grad?must?be?specified?for?non?scalar")
????????else:
????????????self._grad?=?ensure_tensor(grad)
此時(shí)它也通過(guò)了:
test_add.py::test_array_add?PASSED???????????????????????????????????????[100%]
第三個(gè)測(cè)試方法又沒(méi)通過(guò):
#?t.shape要和grad.shape保持一致
>???????????????????assert?t.shape?==?g.shape,?f"grad?shape?must?match?tensor?shape?in?{self._ctx!r},?{g.shape!r}?!=?{t.shape!r}"
E???????????????????AssertionError:?grad?shape?must?match?tensor?shape?in?,?(2,?3)?!=?(3,)
說(shuō)的是,梯度形狀不一致的問(wèn)題。我們知道,梯度的形狀要和輸入保持一致。
對(duì)于 z = x + y,如果x.shape == y.shape,那么就像上面的例子一樣,沒(méi)什么問(wèn)題;
如果x.shape == (2,3) y.shape == (3,) 那么,根據(jù)廣播,先會(huì)在y左邊插入一個(gè)維度1,變成 -> y.shape == (1,3),接著,在第0個(gè)維度上進(jìn)行復(fù)制,使得新的維度 y.shape == (2,3)。這樣的話,對(duì)x求梯度時(shí),梯度要和x的shape保持一致;對(duì)y求梯度時(shí),也要和y的shape保持一致。
修復(fù)廣播帶來(lái)的問(wèn)題
由于要保證梯度的維度和輸入的維度一致,而最后得到的梯度是經(jīng)過(guò)了廣播操作的。所以,我們要實(shí)現(xiàn)廣播操作的逆操作:
def?unbroadcast(grad:?np.ndarray,?in_shape:?Tuple)?->?np.ndarray:
????'''
????廣播操作的逆操作,確保grad轉(zhuǎn)換成in_shape的形狀
????Args:
????????grad:?梯度
????????in_shape:?梯度要轉(zhuǎn)換的形狀
????Returns:
????'''
????#?首先計(jì)算維度個(gè)數(shù)之差
????ndims_added?=?grad.ndim?-?len(in_shape)
????#?由于廣播時(shí),先從左邊插入,再進(jìn)行復(fù)制,所以逆操作時(shí),也從左邊開(kāi)始,進(jìn)行復(fù)制的逆操作(求和)
????for?_?in?range(ndims_added):
????????#?在axis=0上進(jìn)行求和,去掉第0個(gè)維度,如果ndims_added?>?1,就需要不停的在第0個(gè)維度上面求和
????????grad?=?grad.sum(axis=0)
????return?grad
這樣,假設(shè)輸入的維度是,梯度的維度。那么上面的代碼,首先計(jì)算出維度個(gè)數(shù)差值為。
然后grad.sum(axis=0),把梯度的維度。此時(shí)剛好和輸入的維度一致。我們的這個(gè)測(cè)試用例應(yīng)該可以跑通。
test_add.py::test_broadcast_add?PASSED???????????????????????????????????[100%]
我們?cè)賹懸粋€(gè)測(cè)試方法:
def?test_broadcast_add2():
????x?=?Tensor(np.random.randn(2,?3),?requires_grad=True)??#?(2,3)
????y?=?Tensor(np.random.randn(1,?3),?requires_grad=True)??#?(1,3)
????z?=?x?+?y??#?(2,3)
????z.backward(Tensor(np.ones_like(x.data)))??#?grad.shape?==?z.shape
????assert?x.grad.data.tolist()?==?np.ones_like(x.data).tolist()
????assert?y.grad.data.tolist()?==?(np.ones_like(y.data)?*?2).tolist()
然后跑跑看:
>???????????????????assert?t.shape?==?g.shape,?f"grad?shape?must?match?tensor?shape?in?{self._ctx!r},?{g.shape!r}?!=?{t.shape!r}"
E???????????????????AssertionError:?grad?shape?must?match?tensor?shape?in?,?(2,?3)?!=?(1,?3)
..\..\core\tensor.py:190:?AssertionError
?? 又沒(méi)有跑通。說(shuō)是(2, 3) != (1, 3)。所以,我們不僅要比較維度個(gè)數(shù)的差值,還要看維度是否含有。
在理解廣播和常見(jiàn)的乘法中,我們知道廣播時(shí)的計(jì)算規(guī)律為:
首先讓所有輸入數(shù)組都向其中形狀最長(zhǎng)的數(shù)組看齊,形狀中不足的部分都通過(guò)在維度左邊加 1 補(bǔ)齊,然后比較對(duì)應(yīng)維度值,需要滿足:
它們是相等的 其他一個(gè)為1
所以我們也要考慮這種維度為的情況。更改后的代碼為:
def?unbroadcast(grad:?np.ndarray,?in_shape:?Tuple)?->?np.ndarray:
????'''
????廣播操作的逆操作,確保grad轉(zhuǎn)換成in_shape的形狀
????Args:
????????grad:?梯度
????????in_shape:?梯度要轉(zhuǎn)換的形狀
????Returns:
????'''
????#?首先計(jì)算維度個(gè)數(shù)之差
????ndims_added?=?grad.ndim?-?len(in_shape)
????#?由于廣播時(shí),先從左邊插入,再進(jìn)行復(fù)制,所以逆操作時(shí),也從左邊開(kāi)始,進(jìn)行復(fù)制的逆操作(求和)
????for?_?in?range(ndims_added):
????????#?在axis=0上進(jìn)行求和,去掉第0個(gè)維度,如果ndims_added?>?1,就需要不停的在第0個(gè)維度上面求和
????????grad?=?grad.sum(axis=0)
????#?處理?(2,3)?+?(1,3)?=>?(2,3)?grad的情況
????#?看in_shape中有沒(méi)有維度=1的情況
????for?i,?dim?in?enumerate(in_shape):
????????if?dim?==?1:
????????????#?那么需要在該axis上求和,并且保持維度?這里(2,3)?=>?(1,3)?grad?就和輸入維度保持一致了
????????????grad?=?grad.sum(axis=i,?keepdims=True)
????return?grad
我們?cè)谂芤幌聹y(cè)試用例:
test_add.py::test_broadcast_add2?PASSED??????????????????????????????????[100%]
至此,加法運(yùn)算反向傳播廣播帶來(lái)的問(wèn)題解決了。
上文說(shuō)過(guò),我們需要先寫一些測(cè)試用例。然后代碼提交都走PR,利用github的action機(jī)制,只有測(cè)試通過(guò)的PR,才能合入主分支。
我們先來(lái)通過(guò)github的action機(jī)制,在代碼提交時(shí)自動(dòng)跑所有的測(cè)試用例。
利用Github來(lái)進(jìn)行自動(dòng)測(cè)試
本文不會(huì)過(guò)多討論Github action實(shí)現(xiàn)細(xì)節(jié),感興趣的可以查詢Github官方文檔。
首先在項(xiàng)目根目錄下創(chuàng)建目錄.github/workflows,然后添加以下文件。

test.yaml:
name:?Unit?Test?Pipeline
#?當(dāng)在這些分支上提交時(shí),執(zhí)行這個(gè)workflow
on:
??push:
????branches:
??????-?'!master'???#?排除master,在其他分支提交代碼時(shí),需要進(jìn)行測(cè)試
??????-?'*'
#?一個(gè)workflow由一個(gè)或多個(gè)job組成
jobs:
??#?此workflow包含一個(gè)job,叫作test
??test:
????#?會(huì)在github提供的ubuntu系統(tǒng)上運(yùn)行測(cè)試代碼
????runs-on:?ubuntu-latest
????strategy:
??????matrix:
????????python-version:?[3.8,?3.9]?#?同時(shí)在這兩個(gè)版本的python上測(cè)試
????#?Steps?represent?a?sequence?of?tasks?that?will?be?executed?as?part?of?the?job
????steps:
??????#?首先下載代碼
??????-?uses:?actions/checkout@v2
??????-?name:?Set?up?Python?${{?matrix.python-version?}}
????????uses:?actions/setup-python@v2
????????with:
??????????python-version:?${{?matrix.python-version?}}?#?指定python版本
??????-?name:?Install?dependencies
????????run:?|
??????????python3?-m?pip?install?--upgrade?pip
??????????pip?install?pytest
??????????pip?install?-r?requirements.txt
??????-?name:?Run?unit?tests?#?跑測(cè)試
????????run:?|
??????????python3?-m?pytest
requirements.txt:
numpy==1.20.1
然后我們就創(chuàng)建一個(gè)分支,并且提交本文的相關(guān)改動(dòng),看能否觸發(fā)以及通過(guò)Github的測(cè)試。

編寫乘法測(cè)試用例
import?numpy?as?np
from?core.tensor?import?Tensor
def?test_simple_mul():
????'''
????測(cè)試簡(jiǎn)單的乘法
????'''
????x?=?Tensor(1,?requires_grad=True)
????y?=?2
????z?=?x?*?y
????z.backward()
????assert?x.grad.data?==?2.0
def?test_array_mul():
????'''
????測(cè)試兩個(gè)同shape的向量乘法
????'''
????x?=?Tensor([1,?2,?3],?requires_grad=True)
????y?=?Tensor([4,?5,?6],?requires_grad=True)
????z?=?x?*?y
????#?對(duì)應(yīng)元素相乘
????assert?z.data.tolist()?==?[4,?10,?18]
????z.backward(Tensor([1,?1,?1]))
????assert?x.grad._data.tolist()?==?y.data.tolist()
????assert?y.grad._data.tolist()?==?x.data.tolist()
????x?*=?0.1
????assert?x.grad?is?None
????#?assert?[0.10000000149011612,?0.20000000298023224,?0.30000001192092896]?==?[0.1,?0.2,?0.3]
????#?assert?x.data.tolist()?==?[0.1,?0.2,?0.3]
????#?需要用近似相等來(lái)判斷
????np.testing.assert_array_almost_equal(x.data,?[0.1,?0.2,?0.3])
def?test_broadcast_mul():
????x?=?Tensor([[1,?2,?3],?[4,?5,?6]],?requires_grad=True)??#?(2,?3)
????y?=?Tensor([7,?8,?9],?requires_grad=True)??#?(3,?)
????z?=?x?*?y??#?(2,3)?*?(3,)?=>?(2,3)?*?(2,3)?->?(2,3)
????assert?z.data.tolist()?==?[[7,?16,?27],?[28,?40,?54]]
????z.backward(Tensor([[1,?1,?1,?],?[1,?1,?1]]))
????
????assert?x.grad.data.tolist()?==?[[7,?8,?9],?[7,?8,?9]]
????assert?y.grad.data.tolist()?==?[5,?7,?9]
E???????????????????AssertionError:?grad?shape?must?match?tensor?shape?in?,?(2,?3)?!=?(3,)
類似加法,對(duì)于乘法我們也要處理廣播導(dǎo)致的梯度維度問(wèn)題。然后再進(jìn)行測(cè)試:
test_mul.py::test_broadcast_mul?PASSED???????????????????????????????????[100%]
總結(jié)
本文我們優(yōu)化了Tensor反向傳播的代碼,同時(shí)修復(fù)了加法和乘法在發(fā)生廣播時(shí)遇到的一些問(wèn)題。由于我們的代碼托管在Github上,所以可以利用GitHub的一些機(jī)制進(jìn)行自動(dòng)測(cè)試。
下篇文章就可以根據(jù)計(jì)算圖實(shí)現(xiàn)其他運(yùn)算的代碼了。
最后一句:BUG,走你!


Markdown筆記神器Typora配置Gitee圖床
不會(huì)真有人覺(jué)得聊天機(jī)器人難吧(一)
Spring Cloud學(xué)習(xí)筆記(一)
沒(méi)有人比我更懂Spring Boot(一)
入門人工智能必備的線性代數(shù)基礎(chǔ)
1.看到這里了就點(diǎn)個(gè)在看支持下吧,你的在看是我創(chuàng)作的動(dòng)力。
2.關(guān)注公眾號(hào),每天為您分享原創(chuàng)或精選文章!
3.特殊階段,帶好口罩,做好個(gè)人防護(hù)。
