捷徑

torch.autograd.Function.forward

static Function.forward(*args, **kwargs)[source]

定義自定義 autograd Function 的 forward 傳遞。

所有子類別都會覆寫此函式。有兩種定義 forward 傳遞的方法

用法 1 (合併 forward 和 ctx)

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

用法 2(分離 forward 和 ctx)

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • forward 不再接受 ctx 參數。

  • 您還必須覆寫 torch.autograd.Function.setup_context() 靜態方法來處理 ctx 物件的設定。output 是 forward 的輸出,inputs 是 forward 輸入的 Tuple。

  • 詳情請參閱 擴展 torch.autograd

context 可用於儲存任意資料,這些資料可以在 backward 傳遞期間檢索。張量不應直接儲存在 ctx 上(儘管目前為了向後相容性,沒有強制執行)。相反,張量應該使用 ctx.save_for_backward() 儲存(如果它們打算在 backward(等效於 vjp)中使用)或使用 ctx.save_for_forward() 儲存(如果它們打算在 jvp 中使用)。

回傳類型

Any

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得您的問題解答

檢視資源