捷徑

torch.linalg.lstsq

torch.linalg.lstsq(A, B, rcond=None, *, driver=None)

計算線性方程式系統的最小平方問題的解。

K\mathbb{K}R\mathbb{R}C\mathbb{C},對於線性系統 AX=BAX = B,其中 AKm×n,BKm×kA \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k}最小平方法問題 定義為

minXKn×kAXBF\min_{X \in \mathbb{K}^{n \times k}} \|AX - B\|_F

其中 F\|-\|_F 表示 Frobenius 範數。

支援 float、double、cfloat 和 cdouble 等 dtype 的輸入。 也支援批次的矩陣,如果輸入是批次的矩陣,則輸出具有相同的批次維度。

driver 選擇將使用的後端函式。 對於 CPU 輸入,有效值為 ‘gels’‘gelsy’‘gelsd‘gelss’。 若要選擇 CPU 上最佳的驅動程式,請考慮

  • 如果 A 是良置的(其條件數不太大),或者您不介意一些精度損失。

    • 對於一般矩陣:‘gelsy’(具有樞軸的 QR)(預設)

    • 如果 A 是滿秩的:‘gels’(QR)

  • 如果 A 不是良置的。

    • ‘gelsd’(三對角線簡化和 SVD)

    • 但如果您遇到記憶體問題:‘gelss’(完整 SVD)。

對於 CUDA 輸入,唯一有效的驅動程式是 ‘gels’,它假設 A 是滿秩的。

另請參閱 這些驅動程式的完整說明

rcond 用於決定 A 中矩陣的有效秩,當 driver 為 (‘gelsy’, ‘gelsd’, ‘gelss’) 其中之一時。 在這種情況下,如果 σi\sigma_iA 的奇異值,並以降序排列,則如果 σircondσ1\sigma_i \leq \text{rcond} \cdot \sigma_1σi\sigma_i 會被捨入為零。 如果 rcond= None (預設值),則 rcond 會設定為 A 的 dtype 的機器精度乘以 max(m, n)

此函數會回傳問題的解和一些額外資訊,以包含四個張量的具名元組 (solution, residuals, rank, singular_values) 形式呈現。 對於形狀分別為 (*, m, n)(*, m, k) 的輸入 AB,此元組包含:

  • solution:最小平方解。 其形狀為 (*, n, k)

  • residuals:解的平方殘差,也就是 AXBF2\|AX - B\|_F^2。 其形狀等於 A 的批次維度。 當 m > nA 中的每個矩陣都是滿秩矩陣時,才會計算此殘差,否則它是一個空的張量。 如果 A 是一批矩陣,且批次中的任何矩陣都不是滿秩矩陣,則會回傳一個空的張量。 此行為可能會在未來的 PyTorch 版本中變更。

  • rankA 中矩陣的秩的張量。 其形狀等於 A 的批次維度。 當 driver 為 (‘gelsy’, ‘gelsd’, ‘gelss’) 其中之一時,才會計算此秩,否則它是一個空的張量。

  • singular_valuesA 中矩陣的奇異值的張量。 其形狀為 (*, min(m, n))。 當 driver 為 (‘gelsd’, ‘gelss’) 其中之一時,才會計算此秩,否則它是一個空的張量。

注意

此函數會以比單獨執行計算更快且數值上更穩定的方式計算 X = A.pinverse() @ B

警告

rcond 的預設值可能會在未來的 PyTorch 版本中變更。 因此,建議使用固定值以避免潛在的重大變更。

參數
  • A (Tensor) – 形狀為 (*, m, n) 的 lhs 張量,其中 * 是零個或多個批次維度。

  • B (Tensor) – 形狀為 (*, m, k) 的 rhs 張量,其中 * 是零個或多個批次維度。

  • rcond (float, optional) – 用於決定 A 的有效秩 (effective rank)。如果 rcond= Nonercond 會被設定為 A 的 dtype 的機器精度 (machine precision) 乘以 max(m, n)。預設值:None

關鍵字參數 (Keyword Arguments)

driver (str, optional) – 要使用的 LAPACK/MAGMA 方法名稱。如果 None,CPU 輸入使用 ‘gelsy’,CUDA 輸入使用 ‘gels’。預設值:None

回傳值 (Returns)

一個具名元組 (named tuple) (solution, residuals, rank, singular_values)

範例 (Examples)

>>> A = torch.randn(1,3,3)
>>> A
tensor([[[-1.0838,  0.0225,  0.2275],
     [ 0.2438,  0.3844,  0.5499],
     [ 0.1175, -0.9102,  2.0870]]])
>>> B = torch.randn(2,3,3)
>>> B
tensor([[[-0.6772,  0.7758,  0.5109],
     [-1.4382,  1.3769,  1.1818],
     [-0.3450,  0.0806,  0.3967]],
    [[-1.3994, -0.1521, -0.1473],
     [ 1.9194,  1.0458,  0.6705],
     [-1.1802, -0.9796,  1.4086]]])
>>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3)
>>> torch.dist(X, torch.linalg.pinv(A) @ B)
tensor(1.5152e-06)

>>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values
>>> torch.dist(S, torch.linalg.svdvals(A))
tensor(2.3842e-07)

>>> A[:, 0].zero_()  # Decrease the rank of A
>>> rank = torch.linalg.lstsq(A, B).rank
>>> rank
tensor([2])

文件

存取 PyTorch 的完整開發者文件

檢視文件 (View Docs)

教學 (Tutorials)

取得適用於初學者和進階開發人員的深入教學課程

檢視教學 (View Tutorials)

資源 (Resources)

尋找開發資源並獲得您的問題解答

檢視資源 (View Resources)