從零實現(xiàn)深度學(xué)習(xí)框架(四)實現(xiàn)自己的Tensor對象

引言
本著“凡我不能創(chuàng)造的,我就不能理解”的思想,本系列文章會基于純Python以及NumPy從零創(chuàng)建自己的深度學(xué)習(xí)框架,該框架類似PyTorch能實現(xiàn)自動求導(dǎo)。
要深入理解深度學(xué)習(xí),從零開始創(chuàng)建的經(jīng)驗非常重要,從自己可以理解的角度出發(fā),盡量不適用外部完備的框架前提下,實現(xiàn)我們想要的模型。本系列文章的宗旨就是通過這樣的過程,讓大家切實掌握深度學(xué)習(xí)底層實現(xiàn),而不是僅做一個調(diào)包俠。
本文基于前面介紹的計算圖知識,開始實現(xiàn)我們自己的深度學(xué)習(xí)框架。
就像PyTorch用Tensor來表示張量一樣,我們也創(chuàng)建一個自己的Tensor。
數(shù)據(jù)類型
由于我們自己的Tensor也需要進(jìn)行矩陣運算,因此我們直接封裝最常用的矩陣運算工具——NumPy。
首先,我們增加幫助函數(shù)來確保用到的數(shù)據(jù)類型為np.ndarray。
#?默認(rèn)數(shù)據(jù)類型
_type?=?np.float32
#?可以轉(zhuǎn)換為Numpy數(shù)組的類型
Arrayable?=?Union[float,?list,?np.ndarray]
def?ensure_array(arrayable:?Arrayable)?->?np.ndarray:
????"""
????:param?arrayable:
????:return:
????"""
????if?isinstance(arrayable,?np.ndarray):
????????#?如果本身是ndarray
????????return?arrayable
????#?轉(zhuǎn)換為Numpy數(shù)組
????return?np.array(arrayable,?dtype=_type)
Tensor初探
所有的代碼都盡量添加類型提示(Typing),已增加代碼的可讀性。接下來,創(chuàng)建我們自己的Tensor實現(xiàn):
class?Tensor:
????def?__init__(self,?data:?Arrayable,?requires_grad:?bool?=?False)?->?None:
????????'''
????????初始化Tensor對象
????????Args:
????????????data:?數(shù)據(jù)
????????????requires_grad:?是否需要計算梯度
????????'''
????????#?data?是?np.ndarray
????????self._data?=?ensure_array(data)
????????self.requires_grad?=?requires_grad
????????#?保存該Tensor的梯度
????????self._grad?=?None
????????if?self.requires_grad:
????????????#?初始化梯度
????????????self.zero_grad()
????????#?用于計算圖的內(nèi)部變量
????????self._ctx?=?None
調(diào)用ensure_array確保傳過來的是一個Numpy數(shù)組。requires_grad表示是否需要計算梯度。
下面增加一些屬性方法(屬于上面Tensor類):
????@property
????def?grad(self):
????????return?self._grad
????@property
????def?data(self)?->?np.ndarray:
????????return?self._data
[email protected]
????def?data(self,?new_data:?np.ndarray)?->?None:
????????self._data?=?ensure_array(new_data)
????????#?重新賦值后就沒有梯度了
????????self._grad?=?None
通過@property來確保梯度是只讀的,同時讓保存的數(shù)據(jù)data是可讀可寫的,當(dāng)修改data時,需要清空梯度。因為綁定的數(shù)據(jù)已經(jīng)發(fā)生了變化。
我們知道Tensor作為張量,它是有形狀(shape)、維度(dimension)等相關(guān)屬性的,下面我們就來實現(xiàn):
????#?****一些常用屬性****
????@property
????def?shape(self)?->?Tuple:
????????'''返回Tensor各維度大小的元素'''
????????return?self.data.shape
????@property
????def?ndim(self)?->?int:
????????'''返回Tensor的維度個數(shù)'''
????????return?self.data.ndim
????@property
????def?dtype(self)?->?np.dtype:
????????'''返回Tensor中數(shù)據(jù)的類型'''
????????return?self.data.dtype
????@property
????def?size(self)?->?int:
????????'''
????????返回Tensor中元素的個數(shù)?等同于np.prod(a.shape)
????????Returns:
????????'''
????????return?self.data.size
在Tensor的初始化方法中,有進(jìn)行梯度初始化的方法,看一下是如何實現(xiàn)的:
????def?zero_grad(self)?->?None:
????????'''
????????將梯度初始化為0
????????Returns:
????????'''
????????self._grad?=?Tensor(np.zeros_like(self.data,?dtype=_type))
為了方便調(diào)試,我們實現(xiàn)了了__repr__方法。同時實現(xiàn)__len_魔法方法,返回數(shù)據(jù)的長度。
????def?__repr__(self)?->?str:
????????return?f"Tensor({self.data},?requires_grad={self.requires_grad})"
????def?__len__(self)?->?int:
????????return?len(self.data)
最后,實現(xiàn)兩個比較有用的方法。
????def?assign(self,?x)?->?"Tensor":
????????'''將x的值賦予當(dāng)前Tensor'''
????????x?=?ensure_tensor(x)
????????#?維度必須一致
????????assert?x.shape?==?self.shape
????????self.data?=?x.data
????????return?self
????def?numpy(self)?->?np.ndarray:
????????"""轉(zhuǎn)換為Numpy數(shù)組"""
????????return?self.data
assign用于給當(dāng)前Tensor賦值,因為我們上面讓data是只讀了,所以需要額外提供這個方法。
numpy則是將當(dāng)前Tensor對象轉(zhuǎn)換為Numpy數(shù)組。
類似ensure_array,我們也提供了一個確保為Tensor的幫助方法。
Tensorable?=?Union["Tensor",?float,?np.ndarray]
def?ensure_tensor(tensoralbe:?Tensorable)?->?"Tensor":
????'''
????確保是Tensor對象
????'''
????if?isinstance(tensoralbe,?Tensor):
????????return?tensoralbe
????return?Tensor(tensoralbe)
測試
寫完代碼進(jìn)行測試是一個好習(xí)慣,我們今天暫且在__main__里面測試:
if?__name__?==?'__main__':
????t?=?Tensor(range(10))
????print(t)
????print(t.shape)
????print(t.size)
????print(t.dtype)
輸出:
Tensor([0.?1.?2.?3.?4.?5.?6.?7.?8.?9.],?requires_grad=False)
(10,)
10
float32
完整代碼
完整代碼筆者上傳到了程序員最大交友網(wǎng)站上去了,地址:??? https://github.com/nlp-greyfoss/metagrad
總結(jié)
本文我們實現(xiàn)了Tensor對象的基本框架,下篇文章就會學(xué)習(xí)如何實現(xiàn)基本的反向傳播。
最后一句:BUG,走你!


Markdown筆記神器Typora配置Gitee圖床
不會真有人覺得聊天機器人難吧(一)
Spring Cloud學(xué)習(xí)筆記(一)
沒有人比我更懂Spring Boot(一)
入門人工智能必備的線性代數(shù)基礎(chǔ)
1.看到這里了就點個在看支持下吧,你的在看是我創(chuàng)作的動力。
2.關(guān)注公眾號,每天為您分享原創(chuàng)或精選文章!
3.特殊階段,帶好口罩,做好個人防護(hù)。
