廣播語意¶
許多 PyTorch 運算支援 NumPy 的廣播語意。請參閱 https://numpy.dev.org.tw/doc/stable/user/basics.broadcasting.html 以取得詳細資訊。
簡而言之,如果 PyTorch 運算支援廣播,則其 Tensor 引數可以自動擴展為相等大小(無需複製資料)。
一般語意¶
如果符合以下規則,則兩個張量是「可廣播的」
每個張量都至少有一個維度。
當迭代維度大小時,從尾隨維度開始,維度大小必須相等,其中一個為 1,或者其中一個不存在。
例如
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)
>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension
# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
如果兩個張量 x
和 y
是「可廣播的」,則結果張量大小的計算方式如下
如果
x
和y
的維度數量不相等,則將 1 前置到維度較少的張量的維度,使它們的長度相等。然後,對於每個維度大小,結果維度大小是
x
和y
沿該維度的大小的最大值。
例如
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty( 3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
原地語意¶
一個複雜之處在於,原地 (in-place) 操作不允許原地張量 (in-place tensor) 因廣播 (broadcast) 而改變形狀。
例如
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])
# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
向後相容性¶
先前版本的 PyTorch 允許某些逐點 (pointwise) 函式在具有不同形狀的張量上執行,只要每個張量中的元素數量相等即可。 然後,逐點操作將通過將每個張量視為 1 維來執行。 PyTorch 現在支持廣播,並且“1 維”逐點行為被認為已棄用,並且在張量不可廣播但具有相同數量的元素的情況下,將產生 Python 警告。
請注意,在兩個張量形狀不同但可以廣播且具有相同數量的元素的情況下,引入廣播可能會導致向後不相容的變更。 例如:
>>> torch.add(torch.ones(4,1), torch.randn(4))
以前會產生大小為:torch.Size([4,1]) 的張量,但現在會產生大小為:torch.Size([4,4]) 的張量。 為了幫助識別程式碼中可能存在的由廣播引入的向後不相容的情況,您可以將 torch.utils.backcompat.broadcast_warning.enabled 設置為 True,這將在這種情況下產生一個 python 警告。
例如
>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.