functorch.make_functional¶
-
functorch.
make_functional
(model, disable_autograd_tracking=False) → func, params[source]¶ 給定一個
torch.nn.Module
,make_functional()
會萃取狀態 (params) 並傳回模型的函式版本func
。這使得可以對model
之參數使用轉換。func
可如下呼叫import torch import torch.nn as nn from functorch import make_functional x = torch.randn(4, 3) model = nn.Linear(3, 3) func, params = make_functional(model) func(params, x)
以下是將 grad 轉換套用於模型參數的範例。
import torch import torch.nn as nn from functorch import make_functional, grad x = torch.randn(4, 3) t = torch.randn(4, 3) model = nn.Linear(3, 3) func, params = make_functional(model) def compute_loss(params, x, t): y = func(params, x) return nn.functional.mse_loss(y, t) grad_weights = grad(compute_loss)(params, x, t)
如果模型有任何緩衝區,請使用
make_functional_with_buffers()
取代。- 參數
model (torch.nn.Module) – 輸入模型。
disable_autograd_tracking (bool) – 旗標,用於停用輸出參數的梯度追蹤。傳回的 params 與原始模型的 params 集合無關。如果是 False (預設值),params 會有
requires_grad=True
(亦即它們可透過正規 PyTorch 自動微分追蹤),與原始模型的 params 的 requires_grad 相符。否則,傳回的 params 會有requires_grad=False
。預設值:False。如果你計畫使用正規的 PyTorch 自動微分 (例如:如果你要呼叫.backward()
或torch.autograd.grad()
,請設定disable_autograd_tracking=False
。否則,如果你只計畫使用 functorch 的梯度轉換,請設定disable_autograd_tracking=True
,以避免透過 PyTorch 自動微分不必要追蹤歷程。
警告
我們已將 functorch 整合至 PyTorch。身為整合的最後一步,從 PyTorch 2.0 開始,functorch.make_functional 已不建議使用,並會在 PyTorch >= 2.3 的未來版本中移除。請使用 torch.func.functional_call 取代;請參閱 PyTorch 2.0 發行說明和/或 torch.func 遷移指南以取得更多詳情 https://pytorch.dev.org.tw/docs/master/func.migrating.html