Implementing a Deep Learning Framework from Scratch (9): Computation Graphs for Common Operations (Part 2)

Introduction
In the spirit of "What I cannot create, I do not understand," this series builds a deep learning framework from scratch in pure Python and NumPy. Like PyTorch, the framework supports automatic differentiation.
To truly understand deep learning, the experience of building from scratch matters a great deal: starting from what we can understand ourselves, and avoiding full-featured external frameworks as much as possible, we implement the models we want. The goal of this series is that, through this process, you genuinely grasp the low-level workings of deep learning instead of remaining a mere library caller.
In the previous article we implemented computation graphs for the common operations; in this one we implement the remaining ones: Max, Slice, Reshape and Transpose.
Max
As usual, we start with the test cases:
from core.tensor import Tensor
import numpy as np


def test_simple_max():
    x = Tensor([1, 2, 3, 6, 7, 9, 2], requires_grad=True)
    z = x.max()
    assert z.data == [9]
    z.backward()
    assert x.grad.data.tolist() == [0, 0, 0, 0, 0, 1, 0]


def test_simple_max2():
    x = Tensor([1, 2, 3, 9, 7, 9, 2], requires_grad=True)
    z = x.max()
    assert z.data == [9]  # the maximum is still 9
    z.backward()
    # but the maximum occurs twice, so the gradient is split evenly
    assert x.grad.data.tolist() == [0, 0, 0, 0.5, 0, 0.5, 0]


def test_matrix_max():
    a = np.array([[1., 1., 8., 9., 1.],
                  [4., 5., 9., 9., 8.],
                  [8., 6., 9., 7., 9.],
                  [8., 6., 1., 9., 8.]])
    x = Tensor(a, requires_grad=True)
    z = x.max()
    assert z.data == [9]  # the maximum is 9
    z.backward()
    # there are 6 nines in total
    np.testing.assert_array_almost_equal(x.grad.data, [[0, 0, 0, 1 / 6, 0],
                                                       [0, 0, 1 / 6, 1 / 6, 0],
                                                       [0, 0, 1 / 6, 0, 1 / 6],
                                                       [0, 0, 0, 1 / 6, 0]])


def test_matrix_max2():
    a = np.array([[1., 1., 8., 9., 1.],
                  [4., 5., 9., 9., 8.],
                  [8., 6., 9., 7., 9.],
                  [8., 6., 1., 9., 8.]])
    x = Tensor(a, requires_grad=True)
    z = x.max(axis=0)  # [8, 6, 9, 9, 9]
    assert z.data.tolist() == [8, 6, 9, 9, 9]
    z.backward([1, 1, 1, 1, 1])
    grad = [[0., 0., 0., 1 / 3, 0.],
            [0., 0., 0.5, 1 / 3, 0.],
            [0.5, 0.5, 0.5, 0, 1],
            [0.5, 0.5, 0., 1 / 3, 0.]]
    np.testing.assert_array_almost_equal(x.grad.data, np.array(grad))
The derivation behind these gradients is in the earlier article 計算圖運算補充 (the supplement on computation graph operations), so we won't repeat it here.

class Max(_Function):
    def forward(ctx, x: ndarray, axis=None, keepdims=False) -> ndarray:
        ret = np.amax(x, axis=axis, keepdims=keepdims)
        ctx.save_for_backward(x, axis, ret, keepdims)
        return ret

    def backward(ctx, grad: ndarray) -> ndarray:
        x, axis, ret, keepdims = ctx.saved_tensors
        # restore the reduced axis so ret and grad broadcast against x
        # (guards reductions over inner axes when keepdims=False)
        if axis is not None and not keepdims:
            ret = np.expand_dims(ret, axis)
            grad = np.expand_dims(grad, axis)
        # mask marks every position that attains the maximum
        mask = (x == ret)
        # if the maximum occurs several times, split the gradient evenly among them
        div = mask.sum(axis=axis, keepdims=True)
        return mask * grad / div
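To make the "split evenly" rule concrete, here is a standalone NumPy sketch of the same mask logic from backward (plain arrays, no Tensor or ctx involved), reproducing the two-maxima case from test_simple_max2:

import numpy as np

x = np.array([1., 2., 3., 9., 7., 9., 2.])
ret = np.amax(x)          # 9.0
mask = (x == ret)         # [False, False, False, True, False, True, False]
div = mask.sum()          # 2 -- the maximum occurs twice
grad = 1.0                # upstream gradient of the scalar output
print(mask * grad / div)  # [0.  0.  0.  0.5 0.  0.5 0. ]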
Slicing
Slicing is just an indexing operation. The test code is as follows:
from core.tensor import Tensor
import numpy as np


def test_get_by_index():
    x = Tensor([1, 2, 3, 4, 5, 6, 7], requires_grad=True)
    z = x[2]
    assert z.data == 3
    z.backward()
    assert x.grad.data.tolist() == [0, 0, 1, 0, 0, 0, 0]


def test_slice():
    x = Tensor([1, 2, 3, 4, 5, 6, 7], requires_grad=True)
    z = x[2:4]
    assert z.data.tolist() == [3, 4]
    z.backward([1, 1])
    assert x.grad.data.tolist() == [0, 0, 1, 1, 0, 0, 0]


def test_matrix_slice():
    a = np.array([[1., 1., 8., 9., 1.],
                  [4., 5., 9., 9., 8.],
                  [8., 6., 9., 7., 9.],
                  [8., 6., 1., 9., 8.]])
    x = Tensor(a, requires_grad=True)
    z = x[1:3, 2:4]
    assert z.data.tolist() == [[9, 9], [9, 7]]
    z.backward([[1, 1], [1, 1]])
    # only the sliced positions receive gradient
    np.testing.assert_array_almost_equal(x.grad.data, [[0, 0, 0, 0, 0],
                                                       [0, 0, 1, 1, 0],
                                                       [0, 0, 1, 1, 0],
                                                       [0, 0, 0, 0, 0]])

class Slice(_Function):
    def forward(ctx, x: ndarray, idxs: slice) -> ndarray:
        '''
        z = x[idxs]
        '''
        # x[1:3] arrives here as a slice object;
        # a single index may arrive wrapped as an ndarray, so convert it back to an int
        if isinstance(idxs, ndarray):
            idxs = int(idxs.item())
        ctx.save_for_backward(x.shape, idxs)
        return x[idxs]

    def backward(ctx, grad) -> Tuple[ndarray, None]:
        x_shape, idxs = ctx.saved_tensors
        # scatter the gradient back into a zero array of the input's shape
        bigger_grad = np.zeros(x_shape, dtype=grad.dtype)
        bigger_grad[idxs] = grad
        return bigger_grad, None
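The backward pass simply scatters the incoming gradient into a zero array shaped like the input, so positions outside the slice get zero gradient. A quick standalone NumPy check of that idea, mirroring test_slice:

import numpy as np

x_shape = (7,)
idxs = slice(2, 4)         # what x[2:4] passes to forward
grad = np.array([1., 1.])  # upstream gradient, same shape as x[idxs]
bigger_grad = np.zeros(x_shape, dtype=grad.dtype)
bigger_grad[idxs] = grad   # only the sliced positions are filled in
print(bigger_grad)         # [0. 0. 1. 1. 0. 0. 0.]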
Reshape
The backward pass of reshape is actually the simplest of all. Given y = x.reshape(..), backpropagation only has to return a gradient whose shape matches that of x.
Test cases:
import numpy as np
from core.tensor import Tensor


def test_reshape():
    x = Tensor(np.arange(9), requires_grad=True)
    z = x.reshape((3, 3))
    z.backward(np.ones((3, 3)))
    assert x.grad.data.tolist() == np.ones_like(x.data).tolist()


def test_matrix_reshape():
    x = Tensor(np.arange(12).reshape(2, 6), requires_grad=True)
    z = x.reshape((4, 3))
    z.backward(np.ones((4, 3)))
    assert x.grad.data.tolist() == np.ones_like(x.data).tolist()
Implementation:
class Reshape(_Function):
    def forward(ctx, x: ndarray, shape: Tuple) -> ndarray:
        ctx.save_for_backward(x.shape)
        return x.reshape(shape)

    def backward(ctx, grad: ndarray) -> Tuple[ndarray, None]:
        x_shape, = ctx.saved_tensors
        # just reshape the gradient back to the input's shape
        return grad.reshape(x_shape), None
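This works because reshape never changes the flattened (row-major) order of the elements: element i of the output is element i of the flattened input, so the gradient only needs reshaping back. A quick standalone NumPy check of that invariant:

import numpy as np

x = np.arange(12).reshape(2, 6)
y = x.reshape(4, 3)
# reshape preserves flattened order: element i of x is element i of y
assert (x.ravel() == y.ravel()).all()
# hence the gradient just reshapes back to x's shape
grad = np.ones((4, 3))
assert grad.reshape(x.shape).shape == (2, 6)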
Transpose
Reshape and transpose are easy to confuse; in 計算圖運算補充 we analyzed the difference between the two in detail.
For example, reshaping a 2×3 matrix to (3, 2) keeps the elements in their flattened order, while transposing it swaps rows and columns (the original article illustrated this with figures).
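A minimal NumPy demo of the same point the figures made:

import numpy as np

x = np.arange(6).reshape(2, 3)  # [[0 1 2]
                                #  [3 4 5]]
print(x.reshape(3, 2))          # [[0 1] [2 3] [4 5]] -- flattened order preserved
print(x.T)                      # [[0 3] [1 4] [2 5]] -- rows and columns swapped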

Now let's write the test cases:
import numpy as np
from core.tensor import Tensor


def test_transpose():
    x = Tensor(np.arange(6).reshape((2, 3)), requires_grad=True)
    z = x.T
    assert z.data.shape == (3, 2)
    z.backward(np.ones((3, 2)))
    assert x.grad.data.tolist() == np.ones_like(x.data).tolist()


def test_matrix_transpose():
    x = Tensor(np.arange(12).reshape((2, 6, 1)), requires_grad=True)
    z = x.transpose((0, 2, 1))
    assert z.data.shape == (2, 1, 6)
    z.backward(np.ones((2, 1, 6)))
    assert x.grad.data.tolist() == np.ones_like(x.data).tolist()
Implementation:
class Transpose(_Function):
    def forward(ctx, x: ndarray, axes) -> ndarray:
        ctx.save_for_backward(axes)
        return x.transpose(axes)

    def backward(ctx, grad: ndarray) -> Any:
        axes, = ctx.saved_tensors
        # axes=None means a plain transpose, which is its own inverse
        if axes is None:
            return grad.transpose(), None
        # otherwise undo the permutation: argsort(axes) is its inverse
        return grad.transpose(tuple(np.argsort(axes))), None
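The only subtle step is np.argsort(axes): applied to a permutation, argsort returns its inverse, which is exactly what maps the output gradient's axes back to the input's axis order. A quick standalone check, using the same (0, 2, 1) permutation as in the test:

import numpy as np

axes = (0, 2, 1)
inv = tuple(np.argsort(axes))  # (0, 2, 1) happens to be its own inverse
x = np.arange(12).reshape(2, 6, 1)
# transposing forward and then by the inverse permutation round-trips
assert (x.transpose(axes).transpose(inv) == x).all()

# a non-self-inverse example: the inverse of (1, 2, 0) is (2, 0, 1)
print(np.argsort((1, 2, 0)))   # [2 0 1]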
Full Code
I've uploaded the full code to the programmers' biggest social network: https://github.com/nlp-greyfoss/metagrad
Summary
With this, we have implemented computation graphs for essentially all the basic operations we will need. Starting from the next article, we will build deep learning models on top of our automatic differentiation tool.
One last line: BUG, off you go!

