functorch¶
functorch 是 類似 JAX 的 PyTorch 可組合函式變換。
警告
我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,自 PyTorch 2.0 起,functorch API 已被棄用。請改用 torch.func API,並參閱 遷移指南 和 文件 以了解更多詳細信息。
什麼是可組合函式變換?¶
「函式變換」是一種高階函式,它接受一個數值函式並返回一個計算不同數量的新函式。
functorch 具有自動微分變換(
grad(f)
返回一個計算f
梯度的函式)、向量化/批次處理變換(vmap(f)
返回一個計算輸入批次上的f
的函式)等等。這些函式變換可以任意相互組合。例如,組合
vmap(grad(f))
計算一個稱為每樣本梯度的數量,而目前 PyTorch 無法有效地計算該數量。
為什麼要使用可組合函式變換?¶
目前在 PyTorch 中,有許多用例難以實現
計算每樣本梯度(或其他每樣本數量)
在單台機器上運行模型集成
在 MAML 的內部循環中有效地批次處理任務
有效地計算雅可比矩陣和黑塞矩陣
有效地計算批次雅可比矩陣和黑塞矩陣
組合 vmap()
、grad()
和 vjp()
變換允許我們在無需為每個子系統設計單獨的子系統的情況下表達上述內容。這種可組合函式變換的想法來自 JAX 框架。