torch.Tensor.register_hook¶
- Tensor.register_hook(hook)[source][source]¶
註冊一個向後鉤子 (backward hook)。
每次計算相對於 Tensor 的梯度時,都會呼叫 hook。Hook 應具有以下簽名:
hook(grad) -> Tensor or None
Hook 不應修改其引數,但它可以選擇性地返回一個新的梯度,該梯度將取代
grad
。此函式返回一個帶有
handle.remove()
方法的 handle,該方法會從模組中移除 hook。注意
請參閱 反向 Hook 執行 以取得關於何時執行此 hook,以及其執行順序相對於其他 hook 的更多資訊。
範例
>>> v = torch.tensor([0., 0., 0.], requires_grad=True) >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient >>> v.backward(torch.tensor([1., 2., 3.])) >>> v.grad 2 4 6 [torch.FloatTensor of size (3,)] >>> h.remove() # removes the hook