torch.func.linearize¶
- torch.func.linearize(func, *primals)[原始碼]¶
傳回
func
在primals
的值以及在primals
的線性近似。- 參數
func (Callable) – 一個接受一個或多個參數的 Python 函數。
primals (Tensors) – 傳遞給
func
的位置引數,必須都是 Tensor。這些是函數進行線性近似的值。
- 傳回
傳回一個
(output, jvp_fn)
tuple,其中包含將func
應用於primals
的輸出,以及一個計算func
在primals
處評估的 jvp 的函式。- 回傳類型
如果需要在
primals
處多次計算 jvp,則 linearize 非常有用。但是,為了實現這一點,linearize 會儲存中間計算結果,並且比直接應用 jvp 具有更高的記憶體需求。因此,如果所有tangents
都是已知的,則計算 vmap(jvp) 可能比使用 linearize 更有效。注意
linearize 會評估
func
兩次。 請針對單次評估的實作提交 issue。- 範例:
>>> import torch >>> from torch.func import linearize >>> def fn(x): ... return x.sin() ... >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) >>> jvp_fn(torch.ones(3, 3)) tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>>