捷徑

functorch.grad

functorch.grad(func, argnums=0, has_aux=False)[原始碼]

grad 運算子幫助計算 func 關於 argnums 指定的輸入的梯度。此運算子可以嵌套以計算高階梯度。

參數
  • func (Callable) – 一個 Python 函式,接受一個或多個參數。必須返回單元素張量。如果指定 has_auxTrue,則函式可以返回單元素張量和其他輔助物件的元組:(output, aux)

  • argnums (intTuple[int]) – 指定要計算梯度的參數。 argnums 可以是單個整數或整數元組。預設值:0。

  • has_aux (bool) – 標記,指示 func 返回一個張量和其他輔助物件:(output, aux)。預設值:False。

回傳值

計算輸入梯度的函式。預設情況下,函式的輸出是關於第一個參數的梯度張量。如果指定 has_auxTrue,則返回梯度和輸出輔助物件的元組。如果 argnums 是一個整數元組,則返回關於每個 argnums 值的輸出梯度元組。

使用 grad 的範例

>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())

當與 vmap 組合使用時,grad 可用於計算每個樣本的梯度。

>>> # xdoctest: +SKIP
>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

使用帶有 has_auxargnumsgrad 的範例

>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
>>>    loss_per_sample = (0.5 * y_pred - y) ** 2
>>>    loss = loss_per_sample.mean()
>>>    return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))

注意事項

將 PyTorch 的 torch.no_gradgrad 一起使用。

情況 1:在函式內使用 torch.no_grad

>>> # xdoctest: +SKIP
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

在這種情況下,grad(f)(x) 將遵循內部的 torch.no_grad

情況 2:在 torch.no_grad 上下文管理器內使用 grad

>>> # xdoctest: +SKIP
>>> with torch.no_grad():
>>>     grad(f)(x)

在這種情況下,grad 將遵循內部的 torch.no_grad,但不遵循外部的。這是因為 grad 是一個「函式轉換」:其結果不應取決於 f 之外的上下文管理器的結果。

警告

我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,functorch.grad 已從 PyTorch 2.0 開始棄用,並將在 PyTorch >= 2.3 的未來版本中刪除。請改用 torch.func.grad;有關更多詳細資訊,請參閱 PyTorch 2.0 版本說明和/或 torch.func 遷移指南 https://pytorch.dev.org.tw/docs/master/func.migrating.html

文件

取得 PyTorch 完整的開發者文件

查看文件

教學

取得適用於初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源