pytorch編程之廣播語義
許多 PyTorch 操作都支持NumPy Broadcasting Semantics。
簡而言之,如果 PyTorch 操作支持廣播,則其 Tensor 參數(shù)可以自動擴(kuò)展為相等大小(無需復(fù)制數(shù)據(jù))。
一般語義
如果滿足以下規(guī)則,則兩個張量是“可廣播的”:
每個張量具有至少一個維度。
從尾隨尺寸開始迭代尺寸尺寸時,尺寸尺寸必須相等,其中之一為 1,或者不存在其中之一。
例如:
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)
>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension
# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3Copy如果兩個張量x和y是“可廣播的”,則所得張量大小的計算如下:
如果
x和y的維數(shù)不相等,則在張量的維數(shù)前面加 1,以使其長度相等。然后,對于每個尺寸尺寸,所得尺寸尺寸是該尺寸上
x和y尺寸的最大值。
For Example:
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty( 3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1Copy就地語義
一個復(fù)雜之處在于,就地操作不允許就地張量由于廣播而改變形狀。
For Example:
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])
# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.Copy向后兼容
只要每個張量中的元素數(shù)量相等,以前的 PyTorch 版本都可以在具有不同形狀的張量上執(zhí)行某些逐點函數(shù)。然后,通過將每個張量視為一維來執(zhí)行逐點操作。PyTorch 現(xiàn)在支持廣播,并且“一維”按點行為被認(rèn)為已棄用,并且在張量不可廣播但具有相同數(shù)量元素的情況下會生成 Python 警告。
注意,在兩個張量不具有相同形狀但可廣播且具有相同元素數(shù)量的情況下,廣播的引入會導(dǎo)致向后不兼容的更改。例如:
>>> torch.add(torch.ones(4,1), torch.randn(4))
Copy以前會產(chǎn)生一個具有大?。簍orch.Size([4,1])的張量,但現(xiàn)在會產(chǎn)生一個具有以下大?。簍orch.Size([4,4])的張量。為了幫助確定代碼中可能存在廣播引入的向后不兼容的情況,可以將 torch.utils.backcompat.broadcast_warning.enabled 設(shè)置為 True ,這將生成一個 python 在這種情況下發(fā)出警告。
For Example:
>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.