functorch.make_functional_with_buffers¶
-
functorch。
make_functional_with_buffers
(模型, 非活化_自動梯度_追蹤=偽)[原始碼]¶ make_functional(模型,非活化_自動梯度_追蹤=False) -> 函式,參數
給定
torch.nn.Module
,make_functional()
萃取狀態(參數)並傳回模型的函數版本函式
。如此一來,就可以使用模型
參數的轉換。函式
可依下列方式呼叫import torch import torch.nn as nn from functorch import make_functional x = torch.randn(4, 3) model = nn.Linear(3, 3) func, params = make_functional(model) func(params, x)
以下是一個範例,說明如何將梯度轉換套用在模型的參數上。
import torch import torch.nn as nn from functorch import make_functional, grad x = torch.randn(4, 3) t = torch.randn(4, 3) model = nn.Linear(3, 3) func, params = make_functional(model) def compute_loss(params, x, t): y = func(params, x) return nn.functional.mse_loss(y, t) grad_weights = grad(compute_loss)(params, x, t)
如果模型有任何緩衝,請改用
make_functional_with_buffers()
。- 引數
模型 (torch.nn.Module) – 輸入模型。
非活化_自動梯度_追蹤 (布林值) – 旗標,用於禁用輸出參數的梯度追蹤。傳回的參數與原始模型的參數集合無關。如果為 False(預設),參數會
requires_grad=True
(亦即可以用一般 PyTorch 自動梯度追蹤),而且會符合原始模型參數的 requires_grad-ness。否則,傳回的參數會是requires_grad=False
。預設為 False。如果你打算使用一般 PyTorch 自動梯度(例如,你想要呼叫.backward()
或torch.autograd.grad()
,請將非活化_自動梯度_追蹤=False
。否則,如果你只打算使用 functorch 的梯度轉換,請設定非活化_自動梯度_追蹤=True
,以避免使用 PyTorch 自動梯度不必要地追蹤歷程。
警告
我們已將 functorch 整合至 PyTorch。在整合的最後一步驟中,functorch.make_functional_with_buffers 從 PyTorch 2.0 開始已被棄用,並將在未來版本的 PyTorch >= 2.3 中刪除。請改用 torch.func.functional_call;有關詳細資訊,請參閱 PyTorch 2.0 發行筆記和/或 torch.func 遷移指南,https://pytorch.dev.org.tw/docs/master/func.migrating.html