捷徑

torch.Tensor.register_post_accumulate_grad_hook

Tensor.register_post_accumulate_grad_hook(hook)[原始碼][原始碼]

註冊一個反向鉤子 (backward hook),該鉤子在梯度累積之後執行。

此鉤子將在張量的所有梯度累積完成後被呼叫,這表示該張量的 .grad 欄位已更新。後累積梯度鉤子僅適用於葉張量 (leaf tensors,沒有 .grad_fn 欄位的張量)。在非葉張量上註冊此鉤子將會出錯!

該鉤子應具有以下簽名:

hook(param: Tensor) -> None

請注意,與其他 autograd 鉤子不同,此鉤子作用於需要梯度的張量,而不是梯度本身。該鉤子可以原地修改和存取其 Tensor 參數,包括其 .grad 欄位。

此函數返回一個帶有 handle.remove() 方法的處理程式 (handle),該方法從模組中移除鉤子。

注意

有關此鉤子何時執行,以及其相對於其他鉤子的執行順序的更多資訊,請參閱 反向鉤子執行。由於此鉤子在反向傳播期間執行,因此它將在 no_grad 模式下執行 (除非 create_graph 為 True)。如果需要,您可以使用 torch.enable_grad() 在鉤子內重新啟用 autograd。

範例

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> lr = 0.01
>>> # simulate a simple SGD update
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)

>>> h.remove()  # removes the hook

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

獲取初學者和高級開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題的解答

檢視資源