<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          從零實(shí)現(xiàn)深度學(xué)習(xí)框架(五)實(shí)現(xiàn)Tensor的反向傳播

          共 6676字,需瀏覽 14分鐘

           ·

          2021-12-25 20:26

          橫屏觀看,效果更佳!更多文章請(qǐng)關(guān)注公眾號(hào)!

          更多精彩推薦,請(qǐng)關(guān)注我們

          引言

          本著“凡我不能創(chuàng)造的,我就不能理解”的思想,本系列文章會(huì)基于純Python以及NumPy從零創(chuàng)建自己的深度學(xué)習(xí)框架,該框架類(lèi)似PyTorch能實(shí)現(xiàn)自動(dòng)求導(dǎo)。

          要深入理解深度學(xué)習(xí),從零開(kāi)始創(chuàng)建的經(jīng)驗(yàn)非常重要,從自己可以理解的角度出發(fā),盡量不適用外部完備的框架前提下,實(shí)現(xiàn)我們想要的模型。本系列文章的宗旨就是通過(guò)這樣的過(guò)程,讓大家切實(shí)掌握深度學(xué)習(xí)底層實(shí)現(xiàn),而不是僅做一個(gè)調(diào)包俠。

          常見(jiàn)運(yùn)算的計(jì)算圖中,我們了解了加減乘除等運(yùn)算的計(jì)算圖。本文通過(guò)代碼實(shí)現(xiàn)加法和乘法的計(jì)算圖來(lái)了解我們的Tensor自動(dòng)反向傳播計(jì)算梯度的模式。

          實(shí)現(xiàn)運(yùn)算基類(lèi)

          我們是一個(gè)仿PyTorch的自動(dòng)求導(dǎo)深度學(xué)習(xí)框架,為什么要仿PyTorch呢?因?yàn)樗娴姆浅:糜?。而且在這個(gè)過(guò)程會(huì)參考一些PyTorch的實(shí)現(xiàn),這也會(huì)有利于我們對(duì)PyTorch的理解。

          在文章EXTENDING PYTORCH(https://pytorch.org/docs/stable/notes/extending.html)中,介紹了如何在PyTorch中增加新的操作(operation),(1)首先要做的便是創(chuàng)建一個(gè)新的Function的子類(lèi)并實(shí)現(xiàn)forward()backward()方法;(2)然后,調(diào)用ctx參數(shù)上的合適方法;

          forward()是進(jìn)行真正運(yùn)算的代碼,它可以接收任意多的參數(shù)。backward()定義了梯度公式,通常有多少個(gè)輸入,就得返回多少個(gè)相應(yīng)的梯度。但是,有時(shí)并不是所有的參數(shù)都需要計(jì)算梯度,比如切片(Slice)參數(shù)。那么我們可以在相應(yīng)的位置返回None,或者設(shè)置needs_input_grad對(duì)應(yīng)位置為False。

          實(shí)現(xiàn)者需要正確使用forward()ctx中的函數(shù),以確保新函數(shù)的自動(dòng)求導(dǎo)能正確工作:

          • 當(dāng)需要保存forward()中輸入或輸出Tensor以在backward()中使用時(shí)需要調(diào)用save_for_backward()方法。在前向傳播時(shí),建議調(diào)用apply()方法而不是forward()方法。
          • mark_non_differentiable()用于表明某個(gè)輸出不需要計(jì)算梯度。默認(rèn)所有的輸出Tensor只要是可導(dǎo)類(lèi)型都設(shè)置為需要計(jì)算梯度。

          以上節(jié)選自PyTorch官方文檔的內(nèi)容,雖然看起來(lái)好像并不復(fù)雜,但是完全照抄的話還是有些麻煩。我們的實(shí)現(xiàn)當(dāng)然沒(méi)有這么復(fù)雜,我們也有forward()backward()靜態(tài)方法,不需要計(jì)算梯度的參數(shù),我們暫且返回None就好了。

          class?_Function:
          ????def?__init__(self,?*tensors:?"Tensor")?->?None:
          ????????#?該操作所依賴的所有輸入
          ????????self.depends_on?=?[t?for?t?in?tensors]
          ????????#?保存需要在backward()中使用的Tensor或其他對(duì)象(如Shape)
          ????????self.saved_tensors?=?[]

          ????def?__new__(cls,?*args,?**kwargs):
          ????????'''__new__是靜態(tài)方法,當(dāng)該類(lèi)被實(shí)例化時(shí)調(diào)用'''
          ????????#?把這兩個(gè)方法轉(zhuǎn)換為靜態(tài)方法,我們可以通過(guò)類(lèi)名直接調(diào)用
          ????????cls.forward?=?staticmethod(cls.forward)
          ????????cls.backward?=?staticmethod(cls.backward)
          ????????return?super().__new__(cls)

          ????def?save_for_backward(ctx,?*x:?Any)?->?None:
          ????????ctx.saved_tensors.extend(x)

          ????def?forward(ctx,?*args:?Any,?**kwargs:?Any)?->?np.ndarray:
          ????????'''前向傳播,進(jìn)行真正運(yùn)算的地方'''
          ????????raise?NotImplementedError("You?must?implement?the?forward?function?for?custom?Function.")

          ????def?backward(ctx,?grad:?Any)?->?Any:
          ????????'''實(shí)現(xiàn)反向傳播,計(jì)算梯度'''
          ????????raise?NotImplementedError("You?must?implement?the?backward?method?for?your?custom?Function?"
          ??????????????????????????????????"to?use?it?with?backward?mode?AD.")

          ????def?apply(self,?ctx,?*xs:?"Tensor",?**kwargs)?->?"Tensor":
          ????????'''與PyTorch一樣,我們也不直接調(diào)用forward,而是調(diào)用此方法'''
          ????????#?[t.data for t in xs]遍歷Tensor中的data(np.ndarray)值,參與實(shí)際計(jì)算的都是NumPy的數(shù)組。
          ????????ret?=?Tensor(self.forward(ctx,?*[t.data?for?t?in?xs],?**kwargs),
          ?????????????????????requires_grad=any([t.requires_grad?for?t?in?xs]))

          ????????if?ret.requires_grad:
          ????????????ret._ctx?=?ctx

          ????????return?ret

          我們先定義好自己的_Function。然后根據(jù)常見(jiàn)運(yùn)算的計(jì)算圖先實(shí)現(xiàn)簡(jiǎn)單的加減乘除。

          實(shí)現(xiàn)加法運(yùn)算

          class?Add(_Function):

          ????def?forward(ctx,?x:?np.ndarray,?y:?np.ndarray)?->?np.ndarray:
          ????????'''
          ????????實(shí)現(xiàn)?z?=?x?+?y?,我們這里的x和y都是Numpy數(shù)組,因此可能發(fā)生廣播,
          ????????在實(shí)現(xiàn)反向傳播是需要注意
          ????????'''

          ????????#?進(jìn)行真正的運(yùn)算
          ????????return?x?+?y

          ????def?backward(ctx,?grad:?Any)?->?Any:
          ????????#?輸入有兩個(gè),都是需要計(jì)算梯度的,因此輸出也是兩個(gè)
          ????????return?grad,?grad

          加法運(yùn)算,流到的梯度為,就是上面代碼中的grad。

          實(shí)現(xiàn)乘法運(yùn)算

          class?Mul(_Function):

          ????def?forward(ctx,?x:?np.ndarray,?y:?np.ndarray)?->?np.ndarray:
          ????????'''
          ????????實(shí)現(xiàn)?z?=?x?*?y
          ????????'''

          ????????#?乘法需要保存輸入x和y,用于反向傳播
          ????????ctx.save_for_backward(x,?y)
          ????????return?x?*?y

          ????def?backward(ctx,?grad:?Any)?->?Any:
          ????????x,?y?=?ctx.saved_tensors
          ????????#?分別返回?L/?x?和??L/?y
          ????????return?grad?*?y,?grad?*?x

          根據(jù)乘法的計(jì)算圖,實(shí)現(xiàn)起來(lái)也比較簡(jiǎn)單。

          加法和乘法實(shí)現(xiàn)好了,我們下面看如何結(jié)合計(jì)算圖的知識(shí)通過(guò)代碼實(shí)現(xiàn)它們的反向傳播。

          實(shí)現(xiàn)反向傳播

          使用過(guò)PyTorch的童鞋知道,只需要在Tensor上調(diào)用backward()就能計(jì)算梯度。

          本小節(jié),我們也來(lái)實(shí)現(xiàn)這樣的功能。

          自動(dòng)求導(dǎo)神器計(jì)算圖中,我們其實(shí)已經(jīng)看到了如何實(shí)現(xiàn)了。下面通過(guò)代碼來(lái)描述它們。

          和之前介紹的例子一樣,我們也以e = ( a + b ) ? ( b + 1 )為例,期望調(diào)用e.backward()就能得到?ab的梯度grad。

          自動(dòng)求導(dǎo)神器計(jì)算圖中,我們了解了反向模式。我們這里實(shí)現(xiàn)的當(dāng)然就是這種高效的方式。

          Tensor中添加以下方法:

          ????"""
          ?????backward函數(shù)現(xiàn)在應(yīng)該從當(dāng)前節(jié)點(diǎn)(Tensor)回溯到所有依賴節(jié)點(diǎn)(depends_on),計(jì)算路徑上的偏導(dǎo)
          ????????#?我們分為兩部分
          ????????#?a)?遍歷計(jì)算圖
          ????????#????如果c是a經(jīng)過(guò)某個(gè)函數(shù)的結(jié)果(?c=f(a)?),我們無(wú)法知道a的梯度,直到我們得到了c的梯度(鏈?zhǔn)椒▌t)
          ????????#????所以我們需要逆序計(jì)算圖中的拓?fù)浣Y(jié)構(gòu)(reverse?mode),相當(dāng)沿著有向圖的←方向(從指向節(jié)點(diǎn)到起始節(jié)點(diǎn))進(jìn)行計(jì)算
          ????????#?b)?應(yīng)用梯度
          ????????#????現(xiàn)在我們能訪問(wèn)到每個(gè)node,我們用它的backward函數(shù)將梯度傳遞給它們的depends_on
          ????"""


          ????def?_rev_topo_sort(self):
          ????????'''
          ????????a)?遍歷計(jì)算圖,逆序計(jì)算圖中的拓?fù)浣Y(jié)構(gòu)
          ????????Returns:
          ????????'''


          ????????def?visit(node,?visited,?nodes):
          ????????????#?標(biāo)記為已訪問(wèn)
          ????????????visited.add(node)
          ????????????if?node._ctx:
          ????????????????#?遍歷所有依賴節(jié)點(diǎn),遞歸調(diào)用visit
          ????????????????[visit(nd,?visited,?nodes)?for?nd?in?node._ctx.depends_on?if?nd?not?in?visited]
          ????????????????#?遞歸調(diào)用結(jié)束后將node入nodes
          ????????????????nodes.append(node)
          ????????????#?返回遍歷結(jié)果
          ????????????return?nodes

          ????????return?reversed(visit(self,?set(),?[]))

          反向模式的計(jì)算順序相當(dāng)于逆序計(jì)算圖中的拓?fù)浣Y(jié)構(gòu)。我們以e = ( a + b ) ? ( b + 1 )為例,打印該函數(shù)的輸出看。

          if?__name__?==?'__main__':
          ????a,?b?=?Tensor(2,?requires_grad=True),?Tensor(1,?requires_grad=True)
          ????e?=?(a?+?b)?*?(b?+?1)
          ????print(list(e._rev_topo_sort()))?
          [Tensor(6.0,?requires_grad=True),?Tensor(2.0,?requires_grad=True),?Tensor(3.0,?requires_grad=True)]
          計(jì)算圖—前向傳播

          從上面的輸出結(jié)合這張計(jì)算圖來(lái)看,梯度由分別流向了

          我們基于這種反向模式,來(lái)實(shí)現(xiàn)backward()方法。

          ????def?backward(self,?grad:?"Tensor"?=?None)?->?None:
          ????????'''
          ????????實(shí)現(xiàn)Tensor的反向傳播
          ????????Args:
          ????????????grad:?如果該Tensor不是標(biāo)量,則需要傳遞梯度進(jìn)來(lái)

          ????????Returns:

          ????????'''

          ????????#?只能在requires_grad=True的Tensor上調(diào)用此方法
          ????????assert?self.requires_grad,?"called?backward?on?tensor?do?not?require?grad"

          ????????self._grad?=?grad
          ????????#?如果傳遞過(guò)來(lái)的grad為空
          ????????if?grad?is?None:
          ????????????if?self.shape?==?():
          ????????????????#?設(shè)置梯度值為1,grad本身不需要計(jì)算梯度
          ????????????????self._grad?=?Tensor(1)

          ????????for?t?in?self._rev_topo_sort():
          ????????????assert?t.grad?is?not?None
          ????????????#?以逆序計(jì)算梯度,調(diào)用t相關(guān)運(yùn)算操作的backward靜態(tài)方法
          ????????????#?計(jì)算流向其依賴節(jié)點(diǎn)上的梯度(流向其下游)
          ????????????grads?=?t._ctx.backward(t._ctx,?t.grad.data)
          ????????????#?如果只依賴一個(gè)輸入,我們也通過(guò)列表來(lái)封裝,防止zip將其繼續(xù)拆分
          ????????????if?len(t._ctx.depends_on)?==?1:
          ????????????????grads?=?[grads]

          ????????????for?t,?g?in?zip(t._ctx.depends_on,?grads):
          ????????????????#?計(jì)算其下游節(jié)點(diǎn)上的累積梯度,因?yàn)榭赡苡卸鄺l邊
          ????????????????if?t.requires_grad?and?g?is?not?None:
          ????????????????????#?t.shape要和grad.shape保持一致
          ????????????????????assert?t.shape?==?g.shape,?f"grad?shape?must?match?tensor?shape?in?{self._ctx!r},?{g.shape!r}?!=?{t.shape!r}"
          ????????????????????#?grad?Tensor
          ????????????????????gt?=?Tensor(g)
          ????????????????????t._grad?=?gt?if?t.grad?is?None?else?t.grad?+?gt

          下面我們先寫(xiě)出計(jì)算式子,然后像PyTorch一樣直接調(diào)用backward,看能否計(jì)算出對(duì)應(yīng)節(jié)點(diǎn)上的梯度。

          if?__name__?==?'__main__':
          ????a,?b?=?Tensor(2,?requires_grad=True),?Tensor(1,?requires_grad=True)
          ????e?=?(a?+?b)?*?(b?+?1)
          ????e.backward()
          ????print(f'grad?of?a:{a.grad}')
          ????print(f'grad?of?b:{b.grad}')
          grad?of?a:Tensor(2.0,?requires_grad=False)
          grad?of?b:Tensor(5.0,?requires_grad=False)

          完整代碼

          完整代碼筆者上傳到了程序員最大交友網(wǎng)站上去了,地址:??? https://github.com/nlp-greyfoss/metagrad

          總結(jié)

          本文我們實(shí)現(xiàn)了Tensor的反向傳播框架,并實(shí)現(xiàn)了加法和乘法的計(jì)算圖。


          最后一句:BUG,走你!

          Markdown筆記神器Typora配置Gitee圖床
          不會(huì)真有人覺(jué)得聊天機(jī)器人難吧(一)
          Spring Cloud學(xué)習(xí)筆記(一)
          沒(méi)有人比我更懂Spring Boot(一)
          入門(mén)人工智能必備的線性代數(shù)基礎(chǔ)

          1.看到這里了就點(diǎn)個(gè)在看支持下吧,你的在看是我創(chuàng)作的動(dòng)力。
          2.關(guān)注公眾號(hào),每天為您分享原創(chuàng)或精選文章!
          3.特殊階段,帶好口罩,做好個(gè)人防護(hù)。



          瀏覽 89
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  国产免费一区二区三区四区午夜视频 | 超碰碰97 | 996re热精品视频 | 亚洲无码影音先锋 | 91精品在鸭窝精在线观看不卡 |