torch.svd¶
- torch.svd(input, some=True, compute_uv=True, *, out=None)¶
計算矩陣或批次矩陣
input
的奇異值分解 (SVD)。奇異值分解表示為一個 namedtuple (U, S, V),使得input
。其中 是 V 的轉置(對於實數輸入),以及 V 的共軛轉置(對於複數輸入)。如果input
是一批矩陣,那麼 U、S 和 V 也會以與input
相同的批次維度進行批次處理。如果
some
為 True(預設值),則此方法會返回簡化的奇異值分解。在這種情況下,如果input
的最後兩個維度為 m 和 n,則返回的 U 和 V 矩陣將僅包含 min(n, m) 個正交列。如果
compute_uv
為 False,則返回的 U 和 V 將會是填滿零的矩陣,形狀分別為 (m, m) 和 (n, n),並且與input
具有相同的裝置。當compute_uv
為 False 時,參數some
無效。支援 float、double、cfloat 和 cdouble 資料類型的
input
。U 和 V 的 dtypes 與input
的 dtypes 相同。S 將始終為實數值,即使input
是複數也是如此。警告
torch.svd()
已棄用,建議使用torch.linalg.svd()
,並將在未來的 PyTorch 版本中移除。U, S, V = torch.svd(A, some=some, compute_uv=True)
(預設) 應替換為U, S, Vh = torch.linalg.svd(A, full_matrices=not some) V = Vh.mH
_, S, _ = torch.svd(A, some=some, compute_uv=False)
應替換為S = torch.linalg.svdvals(A)
注意
與
torch.linalg.svd()
的差異some
與torch.linalg.svd()
的full_matrices
相反。請注意,兩者的預設值均為 True,因此預設行為實際上是相反的。torch.svd()
返回 V,而torch.linalg.svd()
返回 Vh,即 。如果
compute_uv
為 False,torch.svd()
會為 U 和 Vh 返回填滿零的張量,而torch.linalg.svd()
會返回空張量。
注意
奇異值會以降序返回。如果
input
是一批矩陣,則批次中每個矩陣的奇異值會以降序返回。注意
只有當
compute_uv
為 True 時,S 張量才能用於計算梯度。注意
當
some
為 False 時,在反向傳播中,U[…, :, min(m, n):] 和 V[…, :, min(m, n):] 上的梯度將被忽略,因為這些向量可以是相應子空間的任意基底。注意
在 CPU 上,
torch.linalg.svd()
的實作使用 LAPACK 的常式 ?gesdd(一種分而治之演算法)而不是 ?gesvd 以提高速度。類似地,在 GPU 上,CUDA 10.1.243 及更高版本使用 cuSOLVER 的常式 gesvdj 和 gesvdjBatched,而 CUDA 早期版本則使用 MAGMA 的常式 gesdd。注意
回傳的 U 將不會是連續的。該矩陣(或批次的矩陣)將表示為列優先矩陣(即 Fortran 連續)。
警告
只有當輸入沒有零或重複的奇異值時,相對於 U 和 V 的梯度才是有限的。
警告
如果任何兩個奇異值之間的距離接近於零,則相對於 U 和 V 的梯度在數值上將是不穩定的,因為它們取決於 . 當矩陣具有小的奇異值時也會發生同樣情況,因為這些梯度也取決於 S^{-1}。
警告
對於複數值的
input
,奇異值分解不是唯一的,因為 U 和 V 可以乘以任意相位因子 在每一列上。當input
具有重複的奇異值時也會發生同樣情況,其中可以將 U 和 V 中 spanning 子空間的列乘以旋轉矩陣,並且 結果向量將跨越相同的子空間。不同的平台(例如 NumPy)或不同設備類型上的輸入可能會產生不同的 U 和 V 張量。- 參數
- 關鍵字引數
out (tuple, optional) – 張量的輸出元組
範例
>>> a = torch.randn(5, 3) >>> a tensor([[ 0.2364, -0.7752, 0.6372], [ 1.7201, 0.7394, -0.0504], [-0.3371, -1.0584, 0.5296], [ 0.3550, -0.4022, 1.5569], [ 0.2445, -0.0158, 1.1414]]) >>> u, s, v = torch.svd(a) >>> u tensor([[ 0.4027, 0.0287, 0.5434], [-0.1946, 0.8833, 0.3679], [ 0.4296, -0.2890, 0.5261], [ 0.6604, 0.2717, -0.2618], [ 0.4234, 0.2481, -0.4733]]) >>> s tensor([2.3289, 2.0315, 0.7806]) >>> v tensor([[-0.0199, 0.8766, 0.4809], [-0.5080, 0.4054, -0.7600], [ 0.8611, 0.2594, -0.4373]]) >>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t())) tensor(8.6531e-07) >>> a_big = torch.randn(7, 5, 3) >>> u, s, v = torch.svd(a_big) >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.mT)) tensor(2.6503e-06)