快捷方式

torch.linalg.tensorsolve

torch.linalg.tensorsolve(A, B, dims=None, *, out=None) Tensor

計算系統 torch.tensordot(A, X) = B 的解 X

如果 mB.ndim 的前幾個維度的乘積,而 n 是其餘維度的乘積,則此函數預期 mn 相等。

傳回的張量 x 滿足 tensordot(A, x, dims=x.ndim) == Bx 的形狀為 A[B.ndim:]

如果指定了 dims,則 A 將被重新塑形為

A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0))

支援 float、double、cfloat 和 cdouble 等資料類型輸入。

另請參閱

torch.linalg.tensorinv() 計算 torch.tensordot() 的乘法逆矩陣。

參數
  • A (Tensor) – 用於求解的張量。 其形狀必須滿足 prod(A.shape[:B.ndim]) == prod(A.shape[B.ndim:])

  • B (Tensor) – 形狀為 A.shape[:B.ndim] 的張量。

  • dims (Tuple[int], optional) – 要移動的 A 的維度。 如果為 None,則不移動任何維度。 預設值: None

關鍵字參數

out (Tensor, optional) – 輸出張量。 如果為 None,則忽略。 預設值: None

引發

RuntimeError – 如果上述 m 的重新塑形 A.view(m, m) 不可逆,或者前 ind 個維度的乘積不等於其餘維度的乘積。

範例

>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
>>> B = torch.randn(2 * 3, 4)
>>> X = torch.linalg.tensorsolve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B)
True

>>> A = torch.randn(6, 4, 4, 3, 2)
>>> B = torch.randn(4, 3, 2)
>>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2))
>>> X.shape
torch.Size([6, 4])
>>> A = A.permute(1, 3, 4, 0, 2)
>>> A.shape[B.ndim:]
torch.Size([6, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6)
True

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學課程

取得適合初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源