torch.einsum¶
- torch.einsum(equation, *operands) Tensor [source][source]¶
根據愛因斯坦求和約定,使用符號指定維度,對輸入
operands
的元素乘積求和。Einsum 允許透過基於愛因斯坦求和約定的簡寫格式來計算許多常見的多維線性代數陣列運算,該格式由
equation
指定。此格式的詳細資訊如下所述,但其總體概念是使用某些下標標記輸入operands
的每個維度,並定義哪些下標是輸出的一部分。然後,透過對operands
的元素乘積沿著下標非輸出部分的維度求和來計算輸出。例如,可以使用 einsum 計算矩陣乘法,如 torch.einsum(“ij,jk->ik”, A, B) 所示。在此,j 是求和下標,i 和 k 是輸出下標(有關更多詳細資訊,請參閱下面的章節)。方程式 (Equation)
equation
字串指定輸入operands
每個維度的下標([a-zA-Z] 中的字母),順序與維度相同,並以逗號(‘,’)分隔每個運算元的下標,例如 ‘ij,jk’ 指定兩個 2D 運算元的下標。標記有相同下標的維度必須是可廣播的,也就是說,它們的大小必須匹配或為 1。例外情況是,如果同一個輸入運算元的下標重複出現,在這種情況下,此運算元標記有此下標的維度的大小必須匹配,並且運算元將被其沿這些維度的對角線替換。在equation
中只出現一次的下標將是輸出的一部分,並按字母順序遞增排序。輸出的計算方式是逐元素地乘以輸入operands
,其維度根據下標對齊,然後對下標非輸出部分的維度進行求和。或者,可以透過在方程式末尾添加箭頭(‘->’),後跟輸出下標來顯式定義輸出下標。例如,以下方程式計算矩陣乘法的轉置:‘ij,jk->ki’。輸出下標對於某些輸入運算元必須至少出現一次,對於輸出最多出現一次。
省略符號(‘…’)可以用於代替下標,以廣播省略符號所涵蓋的維度。每個輸入運算元最多可以包含一個省略符號,該省略符號將涵蓋未被下標涵蓋的維度,例如,對於具有 5 個維度的輸入運算元,方程式 ‘ab…c’ 中的省略符號涵蓋第三個和第四個維度。省略符號不需要跨
operands
涵蓋相同數量的維度,但是省略符號的“形狀”(它們涵蓋的維度的大小)必須一起廣播。如果未使用箭頭(‘->’)表示法顯式定義輸出,則省略符號將首先出現在輸出中(最左邊的維度),然後是在輸入運算元中僅出現一次的下標標籤。例如,以下方程式實現了批次矩陣乘法 ‘…ij,…jk’。最後一些注意事項:方程式可能包含不同元素(下標、省略符號、箭頭和逗號)之間的空格,但類似 ‘…’ 的內容無效。空字串 ‘’ 對於純量運算元是有效的。
注意
torch.einsum
處理省略符號(‘…’)的方式與 NumPy 不同,它允許對省略符號涵蓋的維度進行求和,也就是說,省略符號不需要是輸出的一部分。注意
請安裝 opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/) 以取得更高效能的 einsum。您可以透過以下方式安裝:pip install torch[opt-einsum] 或單獨安裝:pip install opt-einsum。
如果 opt-einsum 可用,此函數將透過我們的 opt_einsum 後端
torch.backends.opt_einsum
優化收縮順序來自動加速計算和/或消耗更少的記憶體(_ 和 - 之間令人困惑,我知道)。當至少有三個輸入時,會發生此優化,因為否則順序無關緊要。請注意,尋找最佳路徑是 NP-hard 問題,因此,opt-einsum 依賴於不同的啟發式方法來實現接近最佳的結果。如果 opt-einsum 不可用,則預設順序是從左到右收縮。要繞過此預設行為,請新增以下內容以停用 opt_einsum 並跳過路徑計算:
torch.backends.opt_einsum.enabled = False
要指定您希望 opt_einsum 用於計算收縮路徑的策略,請新增以下行:
torch.backends.opt_einsum.strategy = 'auto'
。預設策略是 ‘auto’,我們也支援 ‘greedy’ 和 ‘optimal’。聲明一下,‘optimal’ 的執行時間是輸入數量的階乘!有關更多詳細資訊,請參閱 opt_einsum 文件 (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。注意
從 PyTorch 1.10 開始,
torch.einsum()
也支援子列表格式(請參閱下面的範例)。在此格式中,每個運算元的下標由子列表指定,子列表是 [0, 52) 範圍內的整數列表。這些子列表跟隨它們的運算元,並且可以在輸入的末尾添加一個額外的子列表來指定輸出的下標,例如 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。可以在子列表中提供 Python 的 Ellipsis 物件,以啟用如上述方程式章節中所述的廣播。範例
>>> # trace >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) >>> # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) >>> # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) >>> # batch matrix multiplication >>> As = torch.randn(3, 2, 5) >>> Bs = torch.randn(3, 5, 4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # with sublist format and ellipsis >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # batch permute >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) >>> # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3, 5, 4) >>> l = torch.randn(2, 5) >>> r = torch.randn(2, 4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])