torch.nn.utils.stateless.functional_call¶
- torch.nn.utils.stateless.functional_call(module, parameters_and_buffers, args=None, kwargs=None, *, tie_weights=True, strict=False)[原始碼][原始碼]¶
透過將模組的參數和緩衝區替換為提供的參數和緩衝區,來對模組執行函數式呼叫。
警告
此 API 已於 PyTorch 2.0 起棄用,並將在 PyTorch 的未來版本中移除。請改用
torch.func.functional_call()
,它可以直接取代此 API。注意
如果模組具有啟用的參數化,則在
parameters_and_buffers
參數中傳遞一個值,且該值的名稱設定為常規參數名稱,將會完全停用參數化。 如果您想將參數化函數應用於傳遞的值,請將鍵設定為{submodule_name}.parametrizations.{parameter_name}.original
。注意
如果模組對參數/緩衝區執行原地 (in-place) 操作,這些操作將反映在 parameters_and_buffers 輸入中。
範例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # does self.foo = self.foo + 1 >>> print(mod.foo) # tensor(0.) >>> functional_call(mod, a, torch.ones(())) >>> print(mod.foo) # tensor(0.) >>> print(a['foo']) # tensor(1.)
注意
如果模組具有綁定權重,則 functional_call 是否遵守綁定取決於 tie_weights 標誌。
範例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied >>> print(mod.foo) # tensor(1.) >>> mod(torch.zeros(())) # tensor(2.) >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
- 參數
module (torch.nn.Module) – 要呼叫的模組
parameters_and_buffers (dict of str 和 Tensor) – 將在模組呼叫中使用的參數。
args (Any 或 tuple) – 要傳遞給模組呼叫的引數。 如果不是元組,則視為單一引數。
kwargs (dict) – 要傳遞給模組呼叫的關鍵字引數
tie_weights (bool, optional) – 如果為 True,則在原始模型中綁定的參數和緩衝區將被視為在重新參數化的版本中是綁定的。 因此,如果為 True 並且傳遞了綁定參數和緩衝區的不同值,則會產生錯誤。 如果為 False,則除非傳遞給兩個權重的值相同,否則它將不遵守原始綁定的參數和緩衝區。 預設值:True。
strict (bool, optional) – 如果為 True,則傳遞的參數和緩衝區必須與原始模組中的參數和緩衝區匹配。 因此,如果為 True 並且有任何遺失或未預期的鍵,則會產生錯誤。 預設值:False。
- 回傳
呼叫
module
的結果。- 回傳類型
Any