torch.func API 參考¶
函數轉換¶
vmap 是向量化的 map; |
|
|
|
回傳一個函式,用於計算梯度和原始值 (primal) 或前向 (forward) 計算的 tuple。 |
|
代表 vector-Jacobian product (向量-雅可比乘積),回傳一個 tuple,包含 |
|
代表 Jacobian-vector product (雅可比-向量乘積),回傳一個 tuple,包含 func(*primals) 的輸出,以及 "在 |
|
回傳 |
|
使用反向模式自動微分計算 |
|
使用正向模式自動微分計算 |
|
透過 forward-over-reverse 策略計算 |
|
functionalize 是一種轉換,可用於從函式中移除(中間的)變更 (mutations) 和別名 (aliasing),同時保留函式的語義。 |
用於處理 torch.nn.Modules 的實用工具¶
一般來說,您可以對呼叫 torch.nn.Module
的函式進行轉換。例如,以下是如何計算一個接收三個值並回傳三個值的函式的雅可比矩陣的範例
model = torch.nn.Linear(3, 3)
def f(x):
return model(x)
x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
但是,如果您想要執行諸如計算模型參數的雅可比矩陣之類的操作,則需要有一種建構函式的方法,其中參數是函式的輸入。這就是 functional_call()
的用途:它接受一個 nn.Module、轉換後的 parameters
以及 Module 前向傳遞的輸入。它回傳使用替換的參數執行 Module 前向傳遞的值。
以下是如何計算參數的雅可比矩陣
model = torch.nn.Linear(3, 3)
def f(params, x):
return torch.func.functional_call(model, params, x)
x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
透過將模組參數和緩衝區替換為提供的參數和緩衝區,對模組執行功能呼叫。 |
|
準備一個 torch.nn.Modules 的清單,以便與 |
|
透過將 |
如果您正在尋找有關修復 Batch Norm 模組的資訊,請按照此處的指南操作