捷徑

functorch.make_functional

functorch.make_functional(model, disable_autograd_tracking=False)func, params[source]

給定一個 torch.nn.Modulemake_functional() 會萃取狀態 (params) 並傳回模型的函式版本 func。這使得可以對 model 之參數使用轉換。

func 可如下呼叫

import torch
import torch.nn as nn
from functorch import make_functional

x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)
func(params, x)

以下是將 grad 轉換套用於模型參數的範例。

import torch
import torch.nn as nn
from functorch import make_functional, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)

def compute_loss(params, x, t):
    y = func(params, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(params, x, t)

如果模型有任何緩衝區,請使用 make_functional_with_buffers() 取代。

參數
  • model (torch.nn.Module) – 輸入模型。

  • disable_autograd_tracking (bool) – 旗標,用於停用輸出參數的梯度追蹤。傳回的 params 與原始模型的 params 集合無關。如果是 False (預設值),params 會有 requires_grad=True (亦即它們可透過正規 PyTorch 自動微分追蹤),與原始模型的 params 的 requires_grad 相符。否則,傳回的 params 會有 requires_grad=False。預設值:False。如果你計畫使用正規的 PyTorch 自動微分 (例如:如果你要呼叫 .backward()torch.autograd.grad(),請設定 disable_autograd_tracking=False。否則,如果你只計畫使用 functorch 的梯度轉換,請設定 disable_autograd_tracking=True,以避免透過 PyTorch 自動微分不必要追蹤歷程。

警告

我們已將 functorch 整合至 PyTorch。身為整合的最後一步,從 PyTorch 2.0 開始,functorch.make_functional 已不建議使用,並會在 PyTorch >= 2.3 的未來版本中移除。請使用 torch.func.functional_call 取代;請參閱 PyTorch 2.0 發行說明和/或 torch.func 遷移指南以取得更多詳情 https://pytorch.dev.org.tw/docs/master/func.migrating.html

文件

瀏覽 PyTorch 的全面開發人員文件

檢視文件

教學課程

取得初學者與進階開發人員的深入教學

檢視教學

資源

尋找開發資源並讓你的問題獲得解答

檢視資源