torch.linalg.tensorinv¶
- torch.linalg.tensorinv(A, ind=2, *, out=None) Tensor ¶
計算
torch.tensordot()
的乘法反元素。如果 m 是
A
的前ind
個維度的乘積,而 n 是剩餘維度的乘積,則此函數預期 m 和 n 相等。 如果是這樣,它會計算一個張量 X,使得 tensordot(A
, X,ind
) 是維度 m 中的單位矩陣。 X 的形狀將與A
相同,但前ind
個維度會被推到最後X.shape == A.shape[ind:] + A.shape[:ind]
支援 float, double, cfloat 和 cdouble dtypes 的輸入。
注意
當
A
是一個 2 維張量且ind
= 1 時,此函數會計算A
的(乘法)反元素 (請參閱torch.linalg.inv()
)。注意
如果可能,考慮使用
torch.linalg.tensorsolve()
從左側用張量反元素乘以張量,因為linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B) # When B is a tensor with shape A.shape[:B.ndim]
在可能的情況下,始終首選使用
tensorsolve()
,因為它比顯式計算偽反元素更快且數值更穩定。另請參閱
torch.linalg.tensorsolve()
計算 torch.tensordot(tensorinv(A
),B
)。- 參數
A (Tensor) – 要反轉的張量。 其形狀必須滿足 prod(
A
.shape[:ind
]) == prod(A
.shape[ind
:])。ind (int) – 用於計算
torch.tensordot()
反元素的索引。 預設值:2。
- 關鍵字參數
out (Tensor, optional) – 輸出張量。 如果為 None,則忽略。 預設值:None。
- 引發
RuntimeError – 如果重塑的
A
不可逆,或者前ind
個維度的乘積不等於剩餘維度的乘積。
範例
>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) >>> Ainv = torch.linalg.tensorinv(A, ind=2) >>> Ainv.shape torch.Size([8, 3, 4, 6]) >>> B = torch.randn(4, 6) >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) True >>> A = torch.randn(4, 4) >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) >>> Ainv = torch.linalg.inv(A) >>> torch.allclose(Atensorinv, Ainv) True