torch.linalg.lstsq¶
- torch.linalg.lstsq(A, B, rcond=None, *, driver=None)¶
計算線性方程式系統的最小平方問題的解。
令 為 或 ,對於線性系統 ,其中 ,最小平方法問題 定義為
其中 表示 Frobenius 範數。
支援 float、double、cfloat 和 cdouble 等 dtype 的輸入。 也支援批次的矩陣,如果輸入是批次的矩陣,則輸出具有相同的批次維度。
driver
選擇將使用的後端函式。 對於 CPU 輸入,有效值為 ‘gels’、‘gelsy’、‘gelsd、‘gelss’。 若要選擇 CPU 上最佳的驅動程式,請考慮如果
A
是良置的(其條件數不太大),或者您不介意一些精度損失。對於一般矩陣:‘gelsy’(具有樞軸的 QR)(預設)
如果
A
是滿秩的:‘gels’(QR)
如果
A
不是良置的。‘gelsd’(三對角線簡化和 SVD)
但如果您遇到記憶體問題:‘gelss’(完整 SVD)。
對於 CUDA 輸入,唯一有效的驅動程式是 ‘gels’,它假設
A
是滿秩的。另請參閱 這些驅動程式的完整說明
rcond
用於決定A
中矩陣的有效秩,當driver
為 (‘gelsy’, ‘gelsd’, ‘gelss’) 其中之一時。 在這種情況下,如果 是 A 的奇異值,並以降序排列,則如果 , 會被捨入為零。 如果rcond
= None (預設值),則rcond
會設定為A
的 dtype 的機器精度乘以 max(m, n)。此函數會回傳問題的解和一些額外資訊,以包含四個張量的具名元組 (solution, residuals, rank, singular_values) 形式呈現。 對於形狀分別為 (*, m, n)、(*, m, k) 的輸入
A
和B
,此元組包含:solution:最小平方解。 其形狀為 (*, n, k)。
residuals:解的平方殘差,也就是 。 其形狀等於
A
的批次維度。 當 m > n 且A
中的每個矩陣都是滿秩矩陣時,才會計算此殘差,否則它是一個空的張量。 如果A
是一批矩陣,且批次中的任何矩陣都不是滿秩矩陣,則會回傳一個空的張量。 此行為可能會在未來的 PyTorch 版本中變更。rank:
A
中矩陣的秩的張量。 其形狀等於A
的批次維度。 當driver
為 (‘gelsy’, ‘gelsd’, ‘gelss’) 其中之一時,才會計算此秩,否則它是一個空的張量。singular_values:
A
中矩陣的奇異值的張量。 其形狀為 (*, min(m, n))。 當driver
為 (‘gelsd’, ‘gelss’) 其中之一時,才會計算此秩,否則它是一個空的張量。
注意
此函數會以比單獨執行計算更快且數值上更穩定的方式計算 X =
A
.pinverse() @B
。警告
rcond
的預設值可能會在未來的 PyTorch 版本中變更。 因此,建議使用固定值以避免潛在的重大變更。- 參數
- 關鍵字參數 (Keyword Arguments)
driver (str, optional) – 要使用的 LAPACK/MAGMA 方法名稱。如果 None,CPU 輸入使用 ‘gelsy’,CUDA 輸入使用 ‘gels’。預設值:None。
- 回傳值 (Returns)
一個具名元組 (named tuple) (solution, residuals, rank, singular_values)。
範例 (Examples)
>>> A = torch.randn(1,3,3) >>> A tensor([[[-1.0838, 0.0225, 0.2275], [ 0.2438, 0.3844, 0.5499], [ 0.1175, -0.9102, 2.0870]]]) >>> B = torch.randn(2,3,3) >>> B tensor([[[-0.6772, 0.7758, 0.5109], [-1.4382, 1.3769, 1.1818], [-0.3450, 0.0806, 0.3967]], [[-1.3994, -0.1521, -0.1473], [ 1.9194, 1.0458, 0.6705], [-1.1802, -0.9796, 1.4086]]]) >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) >>> torch.dist(X, torch.linalg.pinv(A) @ B) tensor(1.5152e-06) >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values >>> torch.dist(S, torch.linalg.svdvals(A)) tensor(2.3842e-07) >>> A[:, 0].zero_() # Decrease the rank of A >>> rank = torch.linalg.lstsq(A, B).rank >>> rank tensor([2])