functorch.hessian¶
-
functorch.
hessian
(func, argnums=0)[原始碼]¶ 透過正向反向策略計算 `func` 關於索引 `argnum` 處的參數的黑塞矩陣。
正向疊反向策略(組成 `
jacfwd(jacrev(func))
)是兼顧效能的良好預設選項。也可以透過其他jacfwd()
和jacrev()
的組合方式來計算 Hessian 矩陣,例如jacfwd(jacfwd(func))
或jacrev(jacrev(func))
。- 參數
- 回傳值
返回一個函式,它接受與
func
相同的輸入,並返回func
關於argnums
指定引數的 Hessian 矩陣。
注意事項
您可能會看到這個 API 因為「運算子 X 未實作正向模式自動微分」而錯誤。如果是這樣,請提交錯誤報告,我們會優先處理。另一種替代方案是使用
jacrev(jacrev(func))
,它有更好的運算子覆蓋率。對於 R^N → R^1 函式的基本用法會得到一個 N x N 的 Hessian 矩陣。
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))
警告
我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,functorch.hessian 從 PyTorch 2.0 開始已被棄用,並將在 PyTorch >= 2.3 的未來版本中刪除。請改用 torch.func.hessian;有關更多詳細資訊,請參閱 PyTorch 2.0 發行說明和/或 torch.func 遷移指南 https://pytorch.dev.org.tw/docs/master/func.migrating.html