快捷方式

torch.lu

torch.lu(*args, **kwargs)[原始碼]

計算矩陣或批次矩陣 A 的 LU 分解。傳回包含 LU 分解和 A 樞軸的元組。如果 pivot 設定為 True,則執行樞軸操作。

警告

建議使用 torch.linalg.lu_factor()torch.linalg.lu_factor_ex(),而不是使用已棄用的 torch.lu()torch.lu() 將會在未來的 PyTorch 版本中移除。 LU, pivots, info = torch.lu(A, compute_pivots) 應該替換為

LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True) 應該替換為

LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

注意

  • 批次中每個矩陣的回傳排列矩陣由大小為 min(A.shape[-2], A.shape[-1]) 的 1-indexed 向量表示。 pivots[i] == j 表示在演算法的第 i 個步驟中,第 i 列與第 j-1 列進行了置換。

  • pivot = False 的 LU 分解不適用於 CPU,嘗試這樣做會引發錯誤。但是,pivot = False 的 LU 分解適用於 CUDA。

  • 如果 get_infosTrue,則此函式不會檢查分解是否成功,因為分解的狀態存在於回傳 tuple 的第三個元素中。

  • 在 CUDA 裝置上大小小於等於 32 的方形矩陣批次中,由於 MAGMA 函式庫中的錯誤(請參閱 magma issue 13),LU 分解會針對奇異矩陣重複進行。

  • 可以使用 torch.lu_unpack() 導出 LUP

警告

只有當 A 為滿秩時,此函式的梯度才會是有限的。這是因為 LU 分解僅在滿秩矩陣上是可微分的。此外,如果 A 接近非滿秩,則梯度在數值上將是不穩定的,因為它取決於 L1L^{-1}U1U^{-1} 的計算。

參數
  • A (Tensor) – 要分解的 tensor,大小為 (,m,n)(*, m, n)

  • pivot (bool, optional) – 控制是否進行 pivoting。預設值:True

  • get_infos (bool, optional) – 如果設定為 True,則傳回 info IntTensor。預設值:False

  • out (tuple, optional) – 可選的輸出 tuple。如果 get_infosTrue,則 tuple 中的元素為 Tensor、IntTensor 和 IntTensor。如果 get_infosFalse,則 tuple 中的元素為 Tensor 和 IntTensor。預設值:None

傳回

包含以下內容的 tensor tuple

  • factorization (Tensor): 大小為 (,m,n)(*, m, n) 的分解

  • pivots (IntTensor):大小為 (,min(m,n))(*, \text{min}(m, n)) 的軸 (pivots)。pivots 儲存所有列的中間換位 (intermediate transpositions)。最終的排列 perm 可以透過對 perm[i]perm[pivots[i] - 1]) 應用 swap(perm[i], perm[pivots[i] - 1]) 來重建,其中 i = 0, ..., pivots.size(-1) - 1,且 perm 最初是 mm 個元素的恆等排列(本質上,這就是 torch.lu_unpack() 所做的事情)。

  • infos (IntTensor, optional):如果 get_infosTrue,則這是一個大小為 ()(*) 的張量,其中非零值表示矩陣或每個小批次的因式分解是否成功或失敗。

回傳類型

(Tensor, IntTensor, IntTensor (optional))

範例

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得初學者和進階開發人員的深入教學

查看教學

資源

尋找開發資源並獲得您的問題解答

查看資源