torch.triangular_solve¶
- torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)¶
求解具有方陣上三角或下三角可逆矩陣 和多個右側向量 的方程組。
用符號表示,它求解 並假設 是方陣上三角矩陣 (如果
upper
= False 則是下三角矩陣),且對角線上沒有零。torch.triangular_solve(b, A) 可以接收 2D 輸入 b, A 或 2D 矩陣的批次輸入。 如果輸入是批次,則回傳批次輸出 X
如果
A
的對角線包含零或非常接近零的元素,且unitriangular
= False (預設值),或者如果輸入矩陣的條件很差,則結果可能包含 NaN。支援 float、double、cfloat 和 cdouble 資料類型的輸入。
警告
torch.triangular_solve()
已被棄用,建議使用torch.linalg.solve_triangular()
,並將在未來的 PyTorch 版本中移除。torch.linalg.solve_triangular()
的參數順序已反轉,且不會回傳其中一個輸入的副本。X = torch.triangular_solve(B, A).solution
應該替換為X = torch.linalg.solve_triangular(A, B)
- 參數
b (Tensor) – 大小為 的多個右側向量,其中 是零個或多個批次維度
A (Tensor) – 大小為 的輸入三角係數矩陣,其中 是零個或多個批次維度
upper (bool, optional) – 是上三角還是下三角。 預設值:
True
。transpose (bool, optional) – 求解 op(A)X = b,其中如果此標誌為
True
則 op(A) = A^T,如果為False
則 op(A) = A。 預設值:False
。unitriangular (bool, optional) – 是否 為單位三角矩陣。若為 True,則 的對角線元素會被假設為 1,且不會從 參考。預設值:
False
。
- 關鍵字參數
out ((Tensor, Tensor), optional) – 用於寫入輸出的兩個張量元組。若為 None 則忽略。預設值:None。
- 回傳值
一個名為 (solution, cloned_coefficient) 的 namedtuple,其中 cloned_coefficient 是 的副本,而 solution 則是 的解 (或系統方程式的任何變體,取決於關鍵字參數)。
範例
>>> A = torch.randn(2, 2).triu() >>> A tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]) >>> b = torch.randn(2, 3) >>> b tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) >>> torch.triangular_solve(b, A) torch.return_types.triangular_solve( solution=tensor([[ 1.7841, 2.9046, -2.5405], [ 1.9320, 0.9270, -1.2826]]), cloned_coefficient=tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]))