捷徑

functorch.hessian

functorch.hessian(func, argnums=0)[原始碼]

透過正向反向策略計算 `func` 關於索引 `argnum` 處的參數的黑塞矩陣。

正向疊反向策略(組成 `jacfwd(jacrev(func)))是兼顧效能的良好預設選項。也可以透過其他 jacfwd()jacrev() 的組合方式來計算 Hessian 矩陣,例如 jacfwd(jacfwd(func))jacrev(jacrev(func))

參數
  • func (函式) – 一個 Python 函式,它接受一個或多個引數(其中至少一個必須是張量),並返回一個或多個張量。

  • argnums (intTuple[int]) – 選用,整數或整數元組,指定要計算 Hessian 矩陣的引數位置。預設值:0。

回傳值

返回一個函式,它接受與 func 相同的輸入,並返回 func 關於 argnums 指定引數的 Hessian 矩陣。

注意事項

您可能會看到這個 API 因為「運算子 X 未實作正向模式自動微分」而錯誤。如果是這樣,請提交錯誤報告,我們會優先處理。另一種替代方案是使用 jacrev(jacrev(func)),它有更好的運算子覆蓋率。

對於 R^N → R^1 函式的基本用法會得到一個 N x N 的 Hessian 矩陣。

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))

警告

我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,functorch.hessian 從 PyTorch 2.0 開始已被棄用,並將在 PyTorch >= 2.3 的未來版本中刪除。請改用 torch.func.hessian;有關更多詳細資訊,請參閱 PyTorch 2.0 發行說明和/或 torch.func 遷移指南 https://pytorch.dev.org.tw/docs/master/func.migrating.html

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適用於初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源