functorch.grad¶
-
functorch.
grad
(func, argnums=0, has_aux=False)[原始碼]¶ grad
運算子幫助計算func
關於argnums
指定的輸入的梯度。此運算子可以嵌套以計算高階梯度。- 參數
- 回傳值
計算輸入梯度的函式。預設情況下,函式的輸出是關於第一個參數的梯度張量。如果指定
has_aux
為True
,則返回梯度和輸出輔助物件的元組。如果argnums
是一個整數元組,則返回關於每個argnums
值的輸出梯度元組。
使用
grad
的範例>>> # xdoctest: +SKIP >>> from torch.func import grad >>> x = torch.randn([]) >>> cos_x = grad(lambda x: torch.sin(x))(x) >>> assert torch.allclose(cos_x, x.cos()) >>> >>> # Second-order gradients >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) >>> assert torch.allclose(neg_sin_x, -x.sin())
當與
vmap
組合使用時,grad
可用於計算每個樣本的梯度。>>> # xdoctest: +SKIP >>> from torch.func import grad, vmap >>> batch_size, feature_size = 3, 5 >>> >>> def model(weights, feature_vec): >>> # Very simple linear model with activation >>> assert feature_vec.dim() == 1 >>> return feature_vec.dot(weights).relu() >>> >>> def compute_loss(weights, example, target): >>> y = model(weights, example) >>> return ((y - target) ** 2).mean() # MSELoss >>> >>> weights = torch.randn(feature_size, requires_grad=True) >>> examples = torch.randn(batch_size, feature_size) >>> targets = torch.randn(batch_size) >>> inputs = (weights, examples, targets) >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
使用帶有
has_aux
和argnums
的grad
的範例>>> # xdoctest: +SKIP >>> from torch.func import grad >>> def my_loss_func(y, y_pred): >>> loss_per_sample = (0.5 * y_pred - y) ** 2 >>> loss = loss_per_sample.mean() >>> return loss, (y_pred, loss_per_sample) >>> >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True) >>> y_true = torch.rand(4) >>> y_preds = torch.rand(4, requires_grad=True) >>> out = fn(y_true, y_preds) >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
注意事項
將 PyTorch 的
torch.no_grad
與grad
一起使用。情況 1:在函式內使用
torch.no_grad
>>> # xdoctest: +SKIP >>> def f(x): >>> with torch.no_grad(): >>> c = x ** 2 >>> return x - c
在這種情況下,
grad(f)(x)
將遵循內部的torch.no_grad
。情況 2:在
torch.no_grad
上下文管理器內使用grad
>>> # xdoctest: +SKIP >>> with torch.no_grad(): >>> grad(f)(x)
在這種情況下,
grad
將遵循內部的torch.no_grad
,但不遵循外部的。這是因為grad
是一個「函式轉換」:其結果不應取決於f
之外的上下文管理器的結果。警告
我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,functorch.grad 已從 PyTorch 2.0 開始棄用,並將在 PyTorch >= 2.3 的未來版本中刪除。請改用 torch.func.grad;有關更多詳細資訊,請參閱 PyTorch 2.0 版本說明和/或 torch.func 遷移指南 https://pytorch.dev.org.tw/docs/master/func.migrating.html