torch.linalg.tensorsolve¶
- torch.linalg.tensorsolve(A, B, dims=None, *, out=None) Tensor ¶
計算系統 torch.tensordot(A, X) = B 的解 X。
如果 m 是
B
.ndim 的前幾個維度的乘積,而 n 是其餘維度的乘積,則此函數預期 m 和 n 相等。傳回的張量 x 滿足 tensordot(
A
, x, dims=x.ndim) ==B
。 x 的形狀為A
[B.ndim:]。如果指定了
dims
,則A
將被重新塑形為A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0))
支援 float、double、cfloat 和 cdouble 等資料類型輸入。
另請參閱
torch.linalg.tensorinv()
計算torch.tensordot()
的乘法逆矩陣。- 參數
- 關鍵字參數
out (Tensor, optional) – 輸出張量。 如果為 None,則忽略。 預設值: None。
- 引發
RuntimeError – 如果上述 m 的重新塑形
A
.view(m, m) 不可逆,或者前ind
個維度的乘積不等於其餘維度的乘積。
範例
>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) >>> B = torch.randn(2 * 3, 4) >>> X = torch.linalg.tensorsolve(A, B) >>> X.shape torch.Size([2, 3, 4]) >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B) True >>> A = torch.randn(6, 4, 4, 3, 2) >>> B = torch.randn(4, 3, 2) >>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2)) >>> X.shape torch.Size([6, 4]) >>> A = A.permute(1, 3, 4, 0, 2) >>> A.shape[B.ndim:] torch.Size([6, 4]) >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6) True