捷徑

functorch

警告

我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,functorch API 從 PyTorch 2.0 開始已被棄用。請改用 torch.func API,並參閱遷移指南文件以瞭解更多詳細資訊。

函數變換

vmap

vmap 是向量化映射;vmap(func) 會傳回一個新的函數,該函數會將 func 映射到輸入的某個維度。

grad

grad 運算子幫助計算 func 相對於 argnums 指定的輸入的梯度。

grad_and_value

傳回一個函數,用於計算梯度和原始(或正向)計算的元組。

vjp

代表向量-雅可比矩陣乘積,傳回一個元組,其中包含將 func 應用於 primals 的結果,以及一個函數,當給定 cotangents 時,計算 func 相對於 primals 的反向模式雅可比矩陣乘以 cotangents 的結果。

jvp

代表雅可比矩陣-向量乘積,傳回一個元組,其中包含 func(*primals) 的輸出和「在 primals 處評估的 func 的雅可比矩陣」乘以 tangents 的結果。

jacrev

使用反向模式自動微分計算 func 相對於索引 argnum 處的參數的雅可比矩陣

jacfwd

使用正向模式自動微分計算 func 相對於索引 argnum 處的參數的雅可比矩陣

hessian

通過正向-反向策略計算 func 相對於索引 argnum 處的參數的海森矩陣。

functionalize

functionalize 是一個可以用於從函數中移除(中間)突變和別名的變換,同時保留函數的語義。

用於處理 torch.nn.Modules 的工具

一般來說,您可以對呼叫 torch.nn.Module 的函數進行變換。例如,以下是如何計算一個函數的雅可比矩陣的範例,該函數接受三個值並傳回三個值

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

但是,如果您想做一些事情,例如計算模型參數的雅可比矩陣,那麼需要一種方法來建構一個函數,其中參數是函數的輸入。這就是 make_functional()make_functional_with_buffers() 的用途:給定一個 torch.nn.Module,它們會傳回一個新的函數,該函數接受 parameters 和模組正向傳遞的輸入。

make_functional

給定一個 torch.nn.Modulemake_functional() 會提取狀態 (params) 並傳回模型的函數版本 func

make_functional_with_buffers

make_functional(model, disable_autograd_tracking=False) -> func, params

combine_state_for_ensemble

使用 vmap() 為集成準備 torch.nn.Modules 列表。

如果您正在尋找有關修復批次正規化模組的資訊,請遵循此處的指南

文件

取得 PyTorch 的完整開發者說明文件

查看文件

教學

取得適用於初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源