5分鐘玩轉PyTorch | 張量廣播計算的本質是什么?
AI因你而升溫,記得加星標哦!
↑?關注 + 星標?,每天學Python新技能
后臺回復【大禮包】送你Python自學大禮包
PyTorch中的張量具有和NumPy相同的廣播特性,允許不同形狀的張量之間進行計算。
廣播的實質特性,其實是低維向量映射到高維之后,相同位置再進行相加。我們重點要學會的就是低維向量如何向高維向量進行映射。
相同形狀的張量計算
雖然我們覺得不同形狀之間的張量計算才是廣播,但其實相同形狀的張量計算本質上也是廣播。
t1?=?torch.arange(3)
t1
#?tensor([0,?1,?2])
#?對應位置元素相加
t1?+?t1
#?tensor([0,?2,?4])
與Python對比
如果兩個list相加,結果是什么?
a?=?[0,?1,?2]
a?+?a
#?[0,?1,?2,?0,?1,?2]
不同形狀的張量計算
廣播的特性是不同形狀的張量進行計算時,一個或多個張量通過隱式轉化成相同形狀的兩個張量,從而完成計算。
但并非任意兩個不同形狀的張量都能進行廣播,因此我們要掌握廣播隱式轉化的核心依據。
2.1 標量和任意形狀的張量
標量(零維張量)可以和任意形狀的張量進行計算,計算過程就是標量和張量的每一個元素進行計算。
#?標量與一維向量
t1?=?torch.arange(3)
#?tensor([0,?1,?2])
t1?+?1?#?等效于t1?+?torch.tensor(1)
#?tensor([1,?2,?3])
#?標量與二維向量
t2?=?torch.zeros((3,?4))
t2?+?1?#?等效于t2?+?torch.tensor(1)
#?tensor([[1.,?1.,?1.,?1.],
#?????????[1.,?1.,?1.,?1.],
#?????????[1.,?1.,?1.,?1.]])
2.2 相同維度,不同形狀張量之間的計算
我們以t2為例來探討相同維度、不同形狀的張量之間的廣播規(guī)則。
t2?=?torch.zeros(3,?4)
t2
#?tensor([[0.,?0.,?0.,?0.],
#?????????[0.,?0.,?0.,?0.],
#?????????[0.,?0.,?0.,?0.]])
t21?=?torch.ones(1,?4)
t21
#?tensor([[1.,?1.,?1.,?1.]])
它們都是二維矩陣,t21的形狀是1×4,t2的形狀是3×4,它們在第一個分量上取值不同,但該分量上t21取值為1,因此可以進行廣播計算:
t2?+?t21
#?tensor([[1.,?1.,?1.,?1.],
#????????[1.,?1.,?1.,?1.],
#????????[1.,?1.,?1.,?1.]])
而t2和t21的實際計算過程如下:
可理解為t21的一行與t2的三行分別進行了相加。而底層原理為t21的形狀由1×4拓展成了t2的3×4,然后二者對應位置進行了相加。
t22?=?torch.ones(3,?1)
t22
#?tensor([[1.],
#?????????[1.],
#?????????[1.]])
t2?+?t22
#?tensor([[1.,?1.,?1.,?1.],
#?????????[1.,?1.,?1.,?1.],
#?????????[1.,?1.,?1.,?1.]])
同理,t22+t2與t21+t2結果相同。如果矩陣的兩個維度都不相同呢?
t23?=?torch.arange(3).reshape(3,?1)
t23
#?tensor([[0],
#?????????[1],
#?????????[2]])
t24?=?torch.arange(3).reshape(1,?3)
#?tensor([[0,?1,?2]])
t23?+?t24
#?tensor([[0,?1,?2],
#?????????[1,?2,?3],
#?????????[2,?3,?4]])
此時,t23的形狀是3×1,而t24的形狀是1×3,二者的形狀在兩個份量上均不同,但都有1存在,因此可以廣播:
如果兩個張量的維度對應數不同且都不為1,那么就無法廣播。
t25?=?torch.ones(2,?4)
#?t2的shape為3×4
t2?+?t25
#?RuntimeError
高維張量的廣播
高維張量的廣播原理與低維張量的廣播原理一致:
t3?=?torch.zeros(2,?3,?4)
t3
#?tensor([[[0.,?0.,?0.,?0.],
#??????????[0.,?0.,?0.,?0.],
#??????????[0.,?0.,?0.,?0.]],
#?????????[[0.,?0.,?0.,?0.],
#?????????[0.,?0.,?0.,?0.],
#?????????[0.,?0.,?0.,?0.]]])
t31?=?torch.ones(2,?3,?1)
t31
#?tensor([[[1.],
#??????????[1.],
#??????????[1.]],
#?????????[[1.],
#??????????[1.],
#??????????[1.]]])
t3+t31
#?tensor([[[1.,?1.,?1.,?1.],
#??????????[1.,?1.,?1.,?1.],
#??????????[1.,?1.,?1.,?1.]],
#?????????[[1.,?1.,?1.,?1.],
#??????????[1.,?1.,?1.,?1.],
#??????????[1.,?1.,?1.,?1.]]])
總結
維度相同時,如果對應分量不同,但有一個為1,就可以廣播。
不同維度計算中的廣播
對于不同維度的張量,我們首先可以將低維的張量升維,然后依據相同維度不同形狀的張量廣播規(guī)則進行廣播。
低維向量的升維也非常簡單,只需將更高維度方向的形狀填充為1即可:
#?創(chuàng)建一個二維向量
t2?=?torch.arange(4).reshape(2,?2)
t2
#?tensor([[0,?1],
#?????????[2,?3]])
#?創(chuàng)建一個三維向量
t3?=?torch.zeros(3,?2,?2)
t3
t2?+?t3
#?tensor([[[0.,?1.],
#??????????[2.,?3.]],
#?????????[[0.,?1.],
#??????????[2.,?3.]],
#?????????[[0.,?1.],
#??????????[2.,?3.]]])
t3和t2的相加,就相當于1×2×2和3×2×2的兩個張量進行計算,廣播規(guī)則與低維張量一致。
相信看完本節(jié),你已經充分掌握了廣播機制的運算規(guī)則:
維度相同時,如果對應分量不同,但有一個為1,就可以廣播 維度不同時,只需將低維向量的更高維度方向的形狀填充為1即可
推薦閱讀
推薦閱讀
