torch.autograd.Function.backward¶
- static Function.backward(ctx, *grad_outputs)[原始碼]¶
定義使用反向模式自動微分來區分運算的公式。
所有子類別都要覆寫此函數。(定義此函數等同於定義
vjp
函數。)它必須接受一個 context
ctx
作為第一個參數,後面接著與forward()
回傳值數量相同的輸出 (對於 forward 函數的非 tensor 輸出,將傳入 None),並且它應該回傳與forward()
輸入數量相同的 tensor。 每個參數是相對於給定輸出的梯度,並且每個回傳值應該是相對於相應輸入的梯度。 如果輸入不是 Tensor 或是一個不需要梯度的 Tensor,您可以直接傳遞 None 作為該輸入的梯度。context 可以用於檢索在 forward 過程中儲存的 tensors。 它也有一個屬性
ctx.needs_input_grad
,作為一個布林值的 tuple,表示每個輸入是否需要梯度。 例如,如果backward()
的第一個輸入需要計算相對於輸出的梯度,則backward()
將會具有ctx.needs_input_grad[0] = True
。- 回傳型別