從零開始深度學(xué)習(xí)Pytorch筆記(4)——張量的拼接與切分


前文傳送門:
從零開始深度學(xué)習(xí)Pytorch筆記(1)——安裝Pytorch
從零開始深度學(xué)習(xí)Pytorch筆記(2)——張量的創(chuàng)建(上)
從零開始深度學(xué)習(xí)Pytorch筆記(3)——張量的創(chuàng)建(下)在該系列的上一篇:從零開始深度學(xué)習(xí)Pytorch筆記(3)——張量的創(chuàng)建(下),我們介紹了更多Pytorch中的張量創(chuàng)建方式,本文研究張量的拼接與切分。
張量的拼接
import?torch
(1) 使用torch.cat()拼接
將張量按維度dim進(jìn)行拼接,不會擴(kuò)張張量的維度
torch.cat(tensors,?dim=0,?out=None)
其中:
tensors:張量序列
dim:要拼接的維度
t?=?torch.ones((3,2))
t0?=?torch.cat([t,t],dim=0)#在第0個維度上拼接
t1?=?torch.cat([t,t],dim=1)#在第1個維度上拼接
print(t0,'\n\n',t1)

t2?=?torch.cat([t,t,t],dim=0)
t2

(2) 使用torch.stack()拼接
在新創(chuàng)建的維度dim上進(jìn)行拼接,會擴(kuò)張張量的維度
torch.stack(tensors,?dim=0,?out=None)
參數(shù):
tensors:張量序列
dim:要拼接的維度
t?=?torch.ones((3,2))
t1?=?torch.stack([t,t],dim=2)#在新創(chuàng)建的維度上進(jìn)行拼接
print(t1,t1.shape)?#拼接完會從2維變成3維

我們可以看到維度從拼接前的(3,2)變成了(3,2,2),即在最后的維度上進(jìn)行了拼接!
t?=?torch.ones((3,2))
t1?=?torch.stack([t,t],dim=0)#在新創(chuàng)建的維度上進(jìn)行拼接
#由于指定是第0維,會把原來的3,2往后移動一格,然后在新的第0維創(chuàng)建新維度進(jìn)行拼接
print(t1,t1.shape)

t?=?torch.ones((3,2))
t1?=?torch.stack([t,t,t],dim=0)#在新創(chuàng)建的維度上進(jìn)行拼接
#由于是第0維,會把原來的3,2往后移動一格,然后在新的第0維創(chuàng)建新維度進(jìn)行拼接
print(t1,t1.shape)

張量的切分
(1) 使用torch.chunk()切分
可以將張量按維度dim進(jìn)行平均切分
return 張量列表
如果不能整除,最后一份張量小于其他張量
torch.chunk(input,?chunks,?dim=0)
參數(shù):
input:要切分的張量
chunks:要切分的份數(shù)
dim:要切分的維度
a?=?torch.ones((5,2))
t?=?torch.chunk(a,dim=0,chunks=2)#在5這個維度切分,切分成2個張量
for?idx,?t_chunk?in?enumerate(t):
????print(idx,t_chunk,t_chunk.shape)

可以看出后一個張量小于前一個張量的,前者第0個維度上是3,后者是2。
(2) 使用torch.split()切分
將張量按維度dim進(jìn)行切分
return:張量列表
torch.split(tensor,?split_size_or_sections,?dim=0)
參數(shù):
tensor:要切分的張量
split_size_or_sections:為int時,表示每一份的長度;為list時,按list元素切分
dim:要切分的維度
a?=?torch.ones((5,2))
t?=?torch.split(a,2,dim=0)#指定了每個張量的長度為2
for?idx,?t_split?in?enumerate(t):
????print(idx,t_split,t_split.shape)#切出3個張量

a?=?torch.ones((5,2))
t?=?torch.split(a,[2,1,2],dim=0)#指定了每個張量的長度為列表中的大小【2,1,2】
for?idx,?t_split?in?enumerate(t):
????print(idx,t_split,t_split.shape)#切出3個張量

a?=?torch.ones((5,2))
t?=?torch.split(a,[2,1,1],dim=0)#list中求和不為長度則拋出異常
for?idx,?t_split?in?enumerate(t):
????print(idx,t_split,t_split.shape)#切出3個張量
RuntimeError:split_with_sizes expects split_sizes to sum exactly to 5 (input tensor's size at dimension 0), but got split_sizes=[2, 1, 1]
歡迎關(guān)注公眾號學(xué)習(xí)之后的連載部分~
你點(diǎn)的每個在看,我都認(rèn)真當(dāng)成了喜歡