torch.linalg.solve¶
- torch.linalg.solve(A, B, *, left=True, out=None) Tensor ¶
計算具有唯一解的線性方程式平方系統的解。
令 為 或 ,此函數計算與 相關的線性系統的解 ,其定義如下
如果
left
= False,此函數會回傳矩陣 ,此矩陣為以下系統的解此線性方程式系統有唯一解的充分必要條件是 是 可逆的。此函數假設 是可逆的。
支援 float、double、cfloat 和 cdouble dtypes 的輸入。 也支援批次的矩陣,如果輸入是批次的矩陣,則輸出具有相同的批次維度。
令 * 為零或多個批次維度,
如果
A
的形狀為 (*, n, n),且B
的形狀為 (*, n) (一批向量) 或形狀為 (*, n, k) (一批矩陣或「多個右手邊」),則此函數會傳回形狀分別為 (*, n) 或 (*, n, k) 的 X。否則,如果
A
的形狀為 (*, n, n),且B
的形狀為 (n,) 或 (n, k),則B
會被廣播為具有形狀 (*, n) 或 (*, n, k)。 然後,此函數會傳回產生的批次線性方程式系統的解。
注意
此函數以比單獨執行計算更快且數值上更穩定的方式計算 X =
A
.inverse() @B
。注意
可以藉由傳遞轉置的輸入
A
和B
並轉置此函數傳回的輸出,來計算系統 的解。注意
允許
A
為非批次的 torch.sparse_csr_tensor,但僅限於 left=True。注意
當輸入位於 CUDA 裝置上時,此函數會將該裝置與 CPU 同步。 如需不同步的版本,請參閱
torch.linalg.solve_ex()
。另請參閱
torch.linalg.solve_triangular()
計算具有唯一解的三角線性方程式系統的解。- 參數
- 關鍵字引數
- 引發
RuntimeError – 如果
A
矩陣不可逆,或批次A
中的任何矩陣不可逆。
範例
>>> A = torch.randn(3, 3) >>> b = torch.randn(3) >>> x = torch.linalg.solve(A, b) >>> torch.allclose(A @ x, b) True >>> A = torch.randn(2, 3, 3) >>> B = torch.randn(2, 3, 4) >>> X = torch.linalg.solve(A, B) >>> X.shape torch.Size([2, 3, 4]) >>> torch.allclose(A @ X, B) True >>> A = torch.randn(2, 3, 3) >>> b = torch.randn(3, 1) >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3, 1) >>> x.shape torch.Size([2, 3, 1]) >>> torch.allclose(A @ x, b) True >>> b = torch.randn(3) >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3) >>> x.shape torch.Size([2, 3]) >>> Ax = A @ x.unsqueeze(-1) >>> torch.allclose(Ax, b.unsqueeze(-1).expand_as(Ax)) True