捷徑

torch.func.linearize

torch.func.linearize(func, *primals)[原始碼]

傳回 funcprimals 的值以及在 primals 的線性近似。

參數
  • func (Callable) – 一個接受一個或多個參數的 Python 函數。

  • primals (Tensors) – 傳遞給 func 的位置引數,必須都是 Tensor。這些是函數進行線性近似的值。

傳回

傳回一個 (output, jvp_fn) tuple,其中包含將 func 應用於 primals 的輸出,以及一個計算 funcprimals 處評估的 jvp 的函式。

回傳類型

Tuple[Any, Callable]

如果需要在 primals 處多次計算 jvp,則 linearize 非常有用。但是,為了實現這一點,linearize 會儲存中間計算結果,並且比直接應用 jvp 具有更高的記憶體需求。因此,如果所有 tangents 都是已知的,則計算 vmap(jvp) 可能比使用 linearize 更有效。

注意

linearize 會評估 func 兩次。 請針對單次評估的實作提交 issue。

範例:
>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>>

文件

存取 PyTorch 的全面開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得您問題的解答

檢視資源