torch.func.hessian¶
- torch.func.hessian(func, argnums=0)[來源]¶
計算
func
相對於索引argnum
處的參數的 Hessian 矩陣,透過 forward-over-reverse 策略。對於良好的效能而言,正向疊反向策略(組成
jacfwd(jacrev(func))
)是一個好的預設選擇。也可以透過jacfwd()
和jacrev()
的其他組合來計算 Hessian 矩陣,例如jacfwd(jacfwd(func))
或jacrev(jacrev(func))
。- 參數
- 回傳
回傳一個函式,該函式接受與
func
相同的輸入,並回傳func
針對位於argnums
的引數所計算的 Hessian 矩陣。
注意
您可能會看到此 API 發生錯誤,顯示「運算子 X 未實作正向模式 AD」。如果發生這種情況,請提交錯誤報告,我們將優先處理。另一種方法是使用
jacrev(jacrev(func))
,它具有更好的運算子涵蓋範圍。使用 R^N -> R^1 函式的基本用法會產生一個 N x N Hessian 矩陣
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))