快捷方式

torch.func.functional_call

torch.func.functional_call(module, parameter_and_buffer_dicts, args=None, kwargs=None, *, tie_weights=True, strict=False)[原始碼]

透過替換模組參數和緩衝區為提供的參數和緩衝區,對模組執行函數呼叫。

注意

如果模組具有作用中的參數化,在 parameter_and_buffer_dicts 參數中傳遞一個值,且該值的名稱設定為常規參數名稱,將會完全停用參數化。如果您想要將參數化函數應用於傳遞的值,請將鍵設定為 {submodule_name}.parametrizations.{parameter_name}.original

注意

如果模組對參數/緩衝區執行原地 (in-place) 操作,這些操作將反映在 parameter_and_buffer_dicts 輸入中。

範例

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # does self.foo = self.foo + 1
>>> print(mod.foo)  # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo)  # tensor(0.)
>>> print(a['foo'])  # tensor(1.)

注意

如果模組具有綁定的權重,functional_call 是否遵守綁定取決於 tie_weights 旗標。

範例

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo)  # tensor(1.)
>>> mod(torch.zeros(()))  # tensor(2.)
>>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)

傳遞多個字典的範例

a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # two separate dictionaries
mod = nn.Bar(1, 1)  # return self.weight @ x + self.buffer
print(mod.weight)  # tensor(...)
print(mod.buffer)  # tensor(...)
x = torch.randn((1, 1))
print(x)
functional_call(mod, a, x)  # same as x
print(mod.weight)  # same as before functional_call

這是一個將 grad 轉換應用於模型參數的範例。

import torch
import torch.nn as nn
from torch.func import functional_call, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

def compute_loss(params, x, t):
    y = functional_call(model, params, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

注意

如果使用者不需要在 grad 轉換之外進行梯度追蹤,他們可以分離 (detach) 所有參數,以獲得更好的效能和記憶體使用率

範例

>>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
>>> grad_weights = grad(compute_loss)(detached_params, x, t)
>>> grad_weights.grad_fn  # None--it's not tracking gradients outside of grad

這意味著使用者無法呼叫 grad_weight.backward()。但是,如果他們不需要轉換之外的自動微分追蹤,這將導致更少的記憶體使用和更快的速度。

參數
  • module (torch.nn.Module) – 要呼叫的模組

  • parameters_and_buffer_dicts (Dict[str, Tensor] 或 Dict[str, Tensor] 的 tuple) – 將在模組呼叫中使用的參數。如果給定一個字典的元組,它們必須具有不同的鍵,以便可以一起使用所有字典

  • args (Anytuple) – 要傳遞給模組呼叫的參數。如果不是元組,則視為單個參數。

  • kwargs (dict) – 要傳遞給模組呼叫的關鍵字參數

  • tie_weights (bool, optional) – 如果為 True,則在原始模型中綁定的參數和緩衝區將被視為在重新參數化的版本中也被綁定。 因此,如果為 True 且為綁定的參數和緩衝區傳遞了不同的值,則會發生錯誤。 如果為 False,則除非為兩個權重傳遞的值相同,否則它將不尊重原始綁定的參數和緩衝區。 預設值:True。

  • strict (bool, optional) – 如果為 True,則傳入的參數和緩衝區必須與原始模組中的參數和緩衝區匹配。 因此,如果為 True 且存在任何遺失或意外的鍵,則會發生錯誤。 預設值:False。

返回

呼叫 module 的結果。

返回類型

Any

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得您的問題解答

檢視資源