torch.lu¶
- torch.lu(*args, **kwargs)[原始碼]¶
計算矩陣或批次矩陣
A
的 LU 分解。傳回包含 LU 分解和A
樞軸的元組。如果pivot
設定為True
,則執行樞軸操作。警告
建議使用
torch.linalg.lu_factor()
和torch.linalg.lu_factor_ex()
,而不是使用已棄用的torch.lu()
。torch.lu()
將會在未來的 PyTorch 版本中移除。LU, pivots, info = torch.lu(A, compute_pivots)
應該替換為LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)
應該替換為LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
注意
批次中每個矩陣的回傳排列矩陣由大小為
min(A.shape[-2], A.shape[-1])
的 1-indexed 向量表示。pivots[i] == j
表示在演算法的第i
個步驟中,第i
列與第j-1
列進行了置換。pivot
=False
的 LU 分解不適用於 CPU,嘗試這樣做會引發錯誤。但是,pivot
=False
的 LU 分解適用於 CUDA。如果
get_infos
為True
,則此函式不會檢查分解是否成功,因為分解的狀態存在於回傳 tuple 的第三個元素中。在 CUDA 裝置上大小小於等於 32 的方形矩陣批次中,由於 MAGMA 函式庫中的錯誤(請參閱 magma issue 13),LU 分解會針對奇異矩陣重複進行。
可以使用
torch.lu_unpack()
導出L
、U
和P
。
警告
只有當
A
為滿秩時,此函式的梯度才會是有限的。這是因為 LU 分解僅在滿秩矩陣上是可微分的。此外,如果A
接近非滿秩,則梯度在數值上將是不穩定的,因為它取決於 和 的計算。- 參數
A (Tensor) – 要分解的 tensor,大小為
pivot (bool, optional) – 控制是否進行 pivoting。預設值:
True
get_infos (bool, optional) – 如果設定為
True
,則傳回 info IntTensor。預設值:False
out (tuple, optional) – 可選的輸出 tuple。如果
get_infos
為True
,則 tuple 中的元素為 Tensor、IntTensor 和 IntTensor。如果get_infos
為False
,則 tuple 中的元素為 Tensor 和 IntTensor。預設值:None
- 傳回
包含以下內容的 tensor tuple
factorization (Tensor): 大小為 的分解
pivots (IntTensor):大小為 的軸 (pivots)。
pivots
儲存所有列的中間換位 (intermediate transpositions)。最終的排列perm
可以透過對perm[i]
和perm[pivots[i] - 1])
應用swap(perm[i], perm[pivots[i] - 1])
來重建,其中i = 0, ..., pivots.size(-1) - 1
,且perm
最初是 個元素的恆等排列(本質上,這就是torch.lu_unpack()
所做的事情)。infos (IntTensor, optional):如果
get_infos
為True
,則這是一個大小為 的張量,其中非零值表示矩陣或每個小批次的因式分解是否成功或失敗。
- 回傳類型
(Tensor, IntTensor, IntTensor (optional))
範例
>>> A = torch.randn(2, 3, 3) >>> A_LU, pivots = torch.lu(A) >>> A_LU tensor([[[ 1.3506, 2.5558, -0.0816], [ 0.1684, 1.1551, 0.1940], [ 0.1193, 0.6189, -0.5497]], [[ 0.4526, 1.2526, -0.3285], [-0.7988, 0.7175, -0.9701], [ 0.2634, -0.9255, -0.3459]]]) >>> pivots tensor([[ 3, 3, 3], [ 3, 3, 3]], dtype=torch.int32) >>> A_LU, pivots, info = torch.lu(A, get_infos=True) >>> if info.nonzero().size(0) == 0: ... print('LU factorization succeeded for all samples!') LU factorization succeeded for all samples!