快捷方式

複數

複數是可以表示為 a+bja + bj 的數字,其中 a 和 b 是實數,而 j 稱為虛數單位,它滿足方程式 j2=1j^2 = -1。複數經常出現在數學和工程學中,尤其是在訊號處理等主題中。傳統上,許多使用者和函式庫(例如,TorchAudio)透過以形狀為 (...,2)(..., 2) 的浮點張量表示數據來處理複數,其中最後一個維度包含實數和虛數值。

複數資料類型的張量在使用複數時提供更自然的使用者體驗。與模擬它們的浮點張量上的操作相比,複數張量上的操作(例如,torch.mv()torch.matmul())可能更快且更節省記憶體。PyTorch 中涉及複數的操作經過最佳化,可以使用向量化組譯指令和專用核心(例如,LAPACK、cuBlas)。

注意

torch.fft 模組中的頻譜操作支援原生複數張量。

警告

複數張量是一個 beta 功能,可能會發生變更。

建立複數張量

我們支援兩種複數資料類型:torch.cfloattorch.cdouble

>>> x = torch.randn(2,2, dtype=torch.cfloat)
>>> x
tensor([[-0.4621-0.0303j, -0.2438-0.5874j],
     [ 0.7706+0.1421j,  1.2110+0.1918j]])

注意

複數張量的預設資料類型由預設的浮點資料類型決定。如果預設的浮點資料類型是 torch.float64,則推斷複數的資料類型為 torch.complex128,否則假定它們的資料類型為 torch.complex64

除了 torch.linspace()torch.logspace()torch.arange() 之外的所有工廠函數都支援複數張量。

從舊表示法轉換

目前使用形狀為 (...,2)(..., 2) 的實數張量來解決缺少複數張量的使用者,可以使用 torch.view_as_complex()torch.view_as_real() 輕鬆地在他們的程式碼中使用複數張量。請注意,這些函數不會執行任何複製,而是傳回輸入張量的檢視。

>>> x = torch.randn(3, 2)
>>> x
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])
>>> y = torch.view_as_complex(x)
>>> y
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])
>>> torch.view_as_real(y)
tensor([[ 0.6125, -0.1681],
     [-0.3773,  1.3487],
     [-0.0861, -0.7981]])

存取實部和虛部

可以使用 realimag 存取複數張量的實數和虛數值。

注意

存取 realimag 屬性不會分配任何記憶體,並且對 realimag 張量進行原地更新將更新原始複數張量。此外,傳回的 realimag 張量不是連續的。

>>> y.real
tensor([ 0.6125, -0.3773, -0.0861])
>>> y.imag
tensor([-0.1681,  1.3487, -0.7981])

>>> y.real.mul_(2)
tensor([ 1.2250, -0.7546, -0.1722])
>>> y
tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j])
>>> y.real.stride()
(2,)

角度和絕對值

可以使用 torch.angle()torch.abs() 計算複數張量的角度和絕對值。

>>> x1=torch.tensor([3j, 4+4j])
>>> x1.abs()
tensor([3.0000, 5.6569])
>>> x1.angle()
tensor([1.5708, 0.7854])

線性代數

許多線性代數運算,如 torch.matmul()torch.linalg.svd()torch.linalg.solve() 等,都支援複數。如果您想請求我們目前不支援的運算,請搜尋是否已提交問題,如果沒有,請提交一個

序列化

複數張量可以序列化,允許將數據儲存為複數值。

>>> torch.save(y, 'complex_tensor.pt')
>>> torch.load('complex_tensor.pt')
tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j])

Autograd

PyTorch 支援複數張量的 autograd。計算出的梯度是共軛 Wirtinger 導數,其負值正是梯度下降演算法中使用的最陡下降方向。因此,所有現有的最佳化器都可以實現為直接使用複數參數。如需更多詳細資訊,請查看 複數的 Autograd 筆記。

最佳化器

在語義上,我們將使用複數參數逐步執行 PyTorch 最佳化器,定義為等同於在 torch.view_as_real() 等效的複數參數上逐步執行相同的最佳化器。 更具體地說:

>>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)]
>>> real_params = [torch.view_as_real(p) for p in params]

>>> complex_optim = torch.optim.AdamW(params)
>>> real_optim = torch.optim.AdamW(real_params)

real_optimcomplex_optim 將對參數計算相同的更新,儘管兩個最佳化器之間可能存在輕微的數值差異,類似於 foreach 與 forloop 最佳化器以及 capturable 與預設最佳化器之間的數值差異。 更多細節請參考 https://pytorch.dev.org.tw/docs/stable/notes/numerical_accuracy.html

具體來說,雖然您可以將我們的最佳化器對複數張量的處理方式視為與分別對它們的 p.realp.imag 部分進行最佳化相同,但實作細節並不完全如此。請注意,torch.view_as_real() 等效項會將複數張量轉換為形狀為 (...,2)(..., 2) 的實數張量,而將複數張量拆分為兩個張量則是 2 個大小為 (...)(...) 的張量。 這種區別對逐點最佳化器(如 AdamW)沒有影響,但會導致在進行全域縮減的最佳化器(如 LBFGS)中產生輕微差異。 我們目前沒有執行逐張量縮減的最佳化器,因此尚未定義此行為。 如果您有需要精確定義此行為的使用案例,請開啟一個 issue。

我們不完全支援以下子系統:

  • 量化

  • JIT

  • 稀疏張量 (Sparse Tensors)

  • 分散式 (Distributed)

如果其中任何一項對您的使用案例有幫助,請搜尋是否已提交 issue,如果沒有,請提交一個

文件

存取全面的 PyTorch 開發人員文件

檢視文件

教學

取得適合初學者和高級開發人員的深入教學課程

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源