torch.autograd.Function.forward¶
- static Function.forward(*args, **kwargs)[source]¶
定義自定義 autograd Function 的 forward 傳遞。
所有子類別都會覆寫此函式。有兩種定義 forward 傳遞的方法
用法 1 (合併 forward 和 ctx)
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必須接受一個 context ctx 作為第一個參數,後接任意數量的參數(張量或其他類型)。
用法 2(分離 forward 和 ctx)
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
forward 不再接受 ctx 參數。
您還必須覆寫
torch.autograd.Function.setup_context()
靜態方法來處理ctx
物件的設定。output
是 forward 的輸出,inputs
是 forward 輸入的 Tuple。詳情請參閱 擴展 torch.autograd
context 可用於儲存任意資料,這些資料可以在 backward 傳遞期間檢索。張量不應直接儲存在 ctx 上(儘管目前為了向後相容性,沒有強制執行)。相反,張量應該使用
ctx.save_for_backward()
儲存(如果它們打算在backward
(等效於vjp
)中使用)或使用ctx.save_for_forward()
儲存(如果它們打算在jvp
中使用)。- 回傳類型