快捷方式

torch.func

torch.func,先前稱為「functorch」,是類似 JAX 的 PyTorch 可組合函式轉換。

注意

這個函式庫目前處於 beta 階段。 這表示功能通常可以運作(除非另有說明),並且我們(PyTorch 團隊)致力於推進這個函式庫。 但是,API 可能會根據使用者回饋進行變更,並且我們沒有涵蓋 PyTorch 的所有操作。

如果您對 API 有任何建議,或想要涵蓋的使用案例,請開啟 GitHub issue 或與我們聯繫。 我們很樂意了解您如何使用這個函式庫。

什麼是可組合函式轉換?

  • 「函式轉換」是一個高階函式,它接受一個數值函式並傳回一個新的函式,該函式計算不同的量。

  • torch.func 具有自動微分轉換(grad(f) 傳回一個計算 f 梯度的函式)、向量化/批次處理轉換(vmap(f) 傳回一個計算 f 在輸入批次上的函式)等等。

  • 這些函數轉換可以任意組合。例如,組合 vmap(grad(f)) 會計算一個稱為「每個樣本梯度 (per-sample-gradients)」的量,這是目前標準 PyTorch 無法有效率計算的。

為什麼要可組合的函數轉換?

目前在 PyTorch 中,有許多用例很難做到

  • 計算每個樣本的梯度(或其他每個樣本的量)

  • 在單一機器上運行模型的集成

  • 在 MAML 的內部迴圈中有效率地將任務批次處理在一起

  • 有效率地計算 Jacobian 和 Hessian 矩陣

  • 有效率地計算批次化的 Jacobian 和 Hessian 矩陣

組合 vmap()grad()vjp() 轉換,讓我們無需為每個用例設計單獨的子系統即可表達上述情況。這種可組合函數轉換的想法來自 JAX 框架

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學課程

取得針對初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源