functorch.vjp¶
-
functorch.
vjp
(func, *primals, has_aux=False)[source]¶ 向量雅可比積的簡寫,傳回包含下列結果的元組:將
func
應用到primals
的結果,以及一個函數,該函數在提供cotangents
時,計算func
對於primals
的反向模式雅可比,再乘以cotangents
。- 參數
func (Callable) – 一個接收一或多個引數的 Python 函數。必須傳回一個或多個張量。
primals (Tensors) – 傳遞給
func
的位置引數,所有內容都必須是張量。傳回的函數也會計算關於這些引數的導數has_aux (bool) – 旗標指示
func
傳回一個(output, aux)
元組,其中第一個元素是函數輸出,準備用於區分,而第二個元素則是不會區分的其他輔助物件。預設值:False.
- 傳回
傳回一個
(output, vjp_fn)
元組,其中包含將func
應用到primals
的輸出,以及一個函數,該函數計算func
對於所有primals
的 VJP,使用傳遞給傳回函數的餘切。如果has_aux is True
,則傳回(output, vjp_fn, aux)
元組。傳回的vjp_fn
函數將傳回一個每個 VJP 的元組。
在簡單案例中使用時,
vjp()
的行為與grad()
相同>>> x = torch.randn([5]) >>> f = lambda x: x.sin().sum() >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> grad = vjpfunc(torch.tensor(1.))[0] >>> assert torch.allclose(grad, torch.func.grad(f)(x))
然而,
vjp()
可以支援有多個輸出的函數,方法是傳入每個輸出的餘切>>> x = torch.randn([5]) >>> f = lambda x: (x.sin(), x.cos()) >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp()
甚至可以支援輸出為 Python 結構>>> x = torch.randn([5]) >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} >>> vjps = vjpfunc(cotangents) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp()
傳回的函數將計算關於每個primals
的偏微分>>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y) >>> cotangents = torch.randn([5, 5]) >>> vjps = vjpfunc(cotangents) >>> assert len(vjps) == 2 >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
primals
是f
的位置引數。所有關鍵字引數都會使用其預設值>>> x = torch.randn([5]) >>> def f(x, scale=4.): >>> return x * scale >>> >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc(torch.ones_like(x)) >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
備註
搭配
vjp
使用 PyTorchtorch.no_grad
。案例 1:在函數內使用torch.no_grad
>>> def f(x): >>> with torch.no_grad(): >>> c = x ** 2 >>> return x - c
在本例中,
vjp(f)(x)
將會尊重內部的torch.no_grad
。案例 2:在
torch.no_grad
語法管理員中使用vjp
>>> # xdoctest: +SKIP(failing) >>> with torch.no_grad(): >>> vjp(f)(x)
在本例中,
vjp
將會尊重內部的torch.no_grad
,但不會尊重外部的。這是因為vjp
是「函數轉換」:它的結果不應依賴於外部f
的語法管理員的結果。警告
我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,functorch.vjp 自 PyTorch 2.0 版開始已標示為不建議使用,且會在 PyTorch >= 2.3 的未來版本中刪除。請改用 torch.func.vjp;有關更多詳細資訊,請參閱 PyTorch 2.0 發行說明和/或 torch.func 移轉指南 https://pytorch.dev.org.tw/docs/master/func.migrating.html