捷徑

torch.func.hessian

torch.func.hessian(func, argnums=0)[來源]

計算 func 相對於索引 argnum 處的參數的 Hessian 矩陣,透過 forward-over-reverse 策略。

對於良好的效能而言,正向疊反向策略(組成 jacfwd(jacrev(func)))是一個好的預設選擇。也可以透過 jacfwd()jacrev() 的其他組合來計算 Hessian 矩陣,例如 jacfwd(jacfwd(func))jacrev(jacrev(func))

參數
  • func (function) – 一個 Python 函式,接受一個或多個引數,其中一個必須是 Tensor,並回傳一個或多個 Tensor

  • argnums (intTuple[int]) – 可選,整數或整數元組,表示要針對哪些引數計算 Hessian 矩陣。預設值:0。

回傳

回傳一個函式,該函式接受與 func 相同的輸入,並回傳 func 針對位於 argnums 的引數所計算的 Hessian 矩陣。

注意

您可能會看到此 API 發生錯誤,顯示「運算子 X 未實作正向模式 AD」。如果發生這種情況,請提交錯誤報告,我們將優先處理。另一種方法是使用 jacrev(jacrev(func)),它具有更好的運算子涵蓋範圍。

使用 R^N -> R^1 函式的基本用法會產生一個 N x N Hessian 矩陣

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源