快速鍵

torch.linalg.tensorinv

torch.linalg.tensorinv(A, ind=2, *, out=None) Tensor

計算 torch.tensordot() 的乘法反元素。

如果 mA 的前 ind 個維度的乘積,而 n 是剩餘維度的乘積,則此函數預期 mn 相等。 如果是這樣,它會計算一個張量 X,使得 tensordot(A, X, ind) 是維度 m 中的單位矩陣。 X 的形狀將與 A 相同,但前 ind 個維度會被推到最後

X.shape == A.shape[ind:] + A.shape[:ind]

支援 float, double, cfloat 和 cdouble dtypes 的輸入。

注意

A 是一個 2 維張量且 ind= 1 時,此函數會計算 A 的(乘法)反元素 (請參閱 torch.linalg.inv())。

注意

如果可能,考慮使用 torch.linalg.tensorsolve() 從左側用張量反元素乘以張量,因為

linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B)  # When B is a tensor with shape A.shape[:B.ndim]

在可能的情況下,始終首選使用 tensorsolve(),因為它比顯式計算偽反元素更快且數值更穩定。

另請參閱

torch.linalg.tensorsolve() 計算 torch.tensordot(tensorinv(A), B)

參數
  • A (Tensor) – 要反轉的張量。 其形狀必須滿足 prod(A.shape[:ind]) == prod(A.shape[ind:])

  • ind (int) – 用於計算 torch.tensordot() 反元素的索引。 預設值:2

關鍵字參數

out (Tensor, optional) – 輸出張量。 如果為 None,則忽略。 預設值:None

引發

RuntimeError – 如果重塑的 A 不可逆,或者前 ind 個維度的乘積不等於剩餘維度的乘積。

範例

>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3))
>>> Ainv = torch.linalg.tensorinv(A, ind=2)
>>> Ainv.shape
torch.Size([8, 3, 4, 6])
>>> B = torch.randn(4, 6)
>>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B))
True

>>> A = torch.randn(4, 4)
>>> Atensorinv = torch.linalg.tensorinv(A, ind=1)
>>> Ainv = torch.linalg.inv(A)
>>> torch.allclose(Atensorinv, Ainv)
True

文件

存取 PyTorch 的綜合開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學課程

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源