torch.Tensor.register_post_accumulate_grad_hook¶
- Tensor.register_post_accumulate_grad_hook(hook)[原始碼][原始碼]¶
註冊一個反向鉤子 (backward hook),該鉤子在梯度累積之後執行。
此鉤子將在張量的所有梯度累積完成後被呼叫,這表示該張量的 .grad 欄位已更新。後累積梯度鉤子僅適用於葉張量 (leaf tensors,沒有 .grad_fn 欄位的張量)。在非葉張量上註冊此鉤子將會出錯!
該鉤子應具有以下簽名:
hook(param: Tensor) -> None
請注意,與其他 autograd 鉤子不同,此鉤子作用於需要梯度的張量,而不是梯度本身。該鉤子可以原地修改和存取其 Tensor 參數,包括其 .grad 欄位。
此函數返回一個帶有
handle.remove()
方法的處理程式 (handle),該方法從模組中移除鉤子。注意
有關此鉤子何時執行,以及其相對於其他鉤子的執行順序的更多資訊,請參閱 反向鉤子執行。由於此鉤子在反向傳播期間執行,因此它將在 no_grad 模式下執行 (除非 create_graph 為 True)。如果需要,您可以使用 torch.enable_grad() 在鉤子內重新啟用 autograd。
範例
>>> v = torch.tensor([0., 0., 0.], requires_grad=True) >>> lr = 0.01 >>> # simulate a simple SGD update >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr)) >>> v.backward(torch.tensor([1., 2., 3.])) >>> v tensor([-0.0100, -0.0200, -0.0300], requires_grad=True) >>> h.remove() # removes the hook