快捷方式

torch.linalg.pinv

torch.linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) Tensor

計算矩陣的偽逆矩陣(摩爾-彭若斯廣義逆矩陣)。

偽逆矩陣可以用代數方式定義,但透過SVD(奇異值分解)來理解它在計算上更為方便。

支援 float、double、cfloat 和 cdouble 等資料型態的輸入。也支援批次矩陣,如果 A 是一批矩陣,則輸出具有相同的批次維度。

如果 hermitian= True,則假定 A 在複數情況下為 Hermitian 矩陣,或在實數情況下為對稱矩陣,但內部不會進行檢查。相反,僅使用矩陣的下三角部分進行計算。

小於閾值 max(atol,σ1rtol)\max(\text{atol}, \sigma_1 \cdot \text{rtol}) 的奇異值(或當 hermitian= True 時,特徵值的範數)在計算中會被視為零並捨棄,其中 σ1\sigma_1 是最大的奇異值(或特徵值)。

如果未指定 rtolA 是維度為 (m, n) 的矩陣,則相對容差設定為 rtol=max(m,n)ε\text{rtol} = \max(m, n) \varepsilon,並且 ε\varepsilonA 的資料型態的 epsilon 值(請參閱 finfo)。如果未指定 rtol 並且指定 atol 大於零,則 rtol 設定為零。

如果 atolrtol 是一個 torch.Tensor,則其形狀必須可廣播到 torch.linalg.svd() 傳回的 A 的奇異值形狀。

注意

如果 hermitian= False,此函數使用 torch.linalg.svd();如果 hermitian= True,則使用 torch.linalg.eigh()。對於 CUDA 輸入,此函數會將該裝置與 CPU 同步。

注意

如果可能,考慮使用 torch.linalg.lstsq() 將矩陣從左側乘以偽逆矩陣,因為

torch.linalg.lstsq(A, B).solution == A.pinv() @ B

在可能的情況下,始終建議使用 lstsq(),因為它比顯式計算偽逆矩陣更快且數值更穩定。

注意

此函數具有與 NumPy 相容的變體 linalg.pinv(A, rcond, hermitian=False)。但是,使用位置引數 rcond 已被棄用,建議改用 rtol

警告

此函數內部使用 torch.linalg.svd()(或當 hermitian= True 時使用 torch.linalg.eigh()),因此其導數與這些函數存在相同的問題。有關更多詳細資訊,請參閱 torch.linalg.svd()torch.linalg.eigh() 中的警告。

另請參閱

torch.linalg.inv() 計算方陣的反矩陣。

torch.linalg.lstsq() 使用數值穩定的演算法計算 A.pinv() @ B

參數
  • A (Tensor) – 形狀為 (*, m, n) 的張量,其中 * 為零或多個批次維度。

  • rcond (float, Tensor, optional) – [NumPy 相容性]. rtol 的別名。預設值:None

關鍵字參數
  • atol (float, Tensor, optional) – 絕對容忍值。 當 None 時,視為零。預設值:None

  • rtol (float, Tensor, optional) – 相對容忍值。 有關當 None 時所取的值,請參閱上文。預設值:None

  • hermitian (bool, optional) – 指示 A 在複數情況下是否為 Hermitian,在實數情況下是否為對稱。預設值:False

  • out (Tensor, optional) – 輸出張量。 如果 None,則忽略。預設值:None

範例

>>> A = torch.randn(3, 5)
>>> A
tensor([[ 0.5495,  0.0979, -1.4092, -0.1128,  0.4132],
        [-1.1143, -0.3662,  0.3042,  1.6374, -0.9294],
        [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]])
>>> torch.linalg.pinv(A)
tensor([[ 0.0600, -0.1933, -0.2090],
        [-0.0903, -0.0817, -0.4752],
        [-0.7124, -0.1631, -0.2272],
        [ 0.1356,  0.3933, -0.5023],
        [-0.0308, -0.1725, -0.5216]])

>>> A = torch.randn(2, 6, 3)
>>> Apinv = torch.linalg.pinv(A)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(8.5633e-07)

>>> A = torch.randn(3, 3, dtype=torch.complex64)
>>> A = A + A.T.conj()  # creates a Hermitian matrix
>>> Apinv = torch.linalg.pinv(A, hermitian=True)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(1.0830e-06)

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源