捷徑

functorch.make_functional_with_buffers

functorch。make_functional_with_buffers(模型, 非活化_自動梯度_追蹤=)[原始碼]

make_functional(模型,非活化_自動梯度_追蹤=False) -> 函式,參數

給定 torch.nn.Modulemake_functional() 萃取狀態(參數)並傳回模型的函數版本 函式。如此一來,就可以使用 模型 參數的轉換。

函式 可依下列方式呼叫

import torch
import torch.nn as nn
from functorch import make_functional

x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)
func(params, x)

以下是一個範例,說明如何將梯度轉換套用在模型的參數上。

import torch
import torch.nn as nn
from functorch import make_functional, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)

def compute_loss(params, x, t):
    y = func(params, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(params, x, t)

如果模型有任何緩衝,請改用 make_functional_with_buffers()

引數
  • 模型 (torch.nn.Module) – 輸入模型。

  • 非活化_自動梯度_追蹤 (布林值) – 旗標,用於禁用輸出參數的梯度追蹤。傳回的參數與原始模型的參數集合無關。如果為 False(預設),參數會 requires_grad=True(亦即可以用一般 PyTorch 自動梯度追蹤),而且會符合原始模型參數的 requires_grad-ness。否則,傳回的參數會是 requires_grad=False。預設為 False。如果你打算使用一般 PyTorch 自動梯度(例如,你想要呼叫 .backward()torch.autograd.grad(),請將 非活化_自動梯度_追蹤=False。否則,如果你只打算使用 functorch 的梯度轉換,請設定 非活化_自動梯度_追蹤=True,以避免使用 PyTorch 自動梯度不必要地追蹤歷程。

警告

我們已將 functorch 整合至 PyTorch。在整合的最後一步驟中,functorch.make_functional_with_buffers 從 PyTorch 2.0 開始已被棄用,並將在未來版本的 PyTorch >= 2.3 中刪除。請改用 torch.func.functional_call;有關詳細資訊,請參閱 PyTorch 2.0 發行筆記和/或 torch.func 遷移指南,https://pytorch.dev.org.tw/docs/master/func.migrating.html

文件

取得 PyTorch 的全面開發人員文件

檢視文件

教學課程

取得初學者和進階開發人員的深入教學

檢視教學

資源

找出開發資源並取得問題解答

檢視資源