torch.func.vjp¶
- torch.func.vjp(func, *primals, has_aux=False)[原始碼]¶
代表向量-雅可比乘積 (vector-Jacobian product),回傳一個包含以下結果的元組:將
func
應用於primals
的結果,以及一個函數。當給定cotangents
時,該函數計算func
相對於primals
的反向模式雅可比矩陣乘以cotangents
的值。- 參數
func (Callable) – 一個接受一或多個參數的 Python 函數。必須回傳一或多個 Tensor。
primals (Tensors) – 傳遞給
func
的位置參數,這些參數都必須是 Tensor。回傳的函數也會計算相對於這些參數的導數。has_aux (bool) – 一個旗標,指示
func
回傳一個(output, aux)
元組,其中第一個元素是要微分的函數的輸出,第二個元素是不會被微分的其他輔助物件。預設值:False。
- 回傳
回傳一個
(output, vjp_fn)
元組,包含將func
應用於primals
的輸出,以及一個計算func
相對於所有primals
的 vjp 的函數,計算時會使用傳遞給回傳函數的 cotangents。如果has_aux is True
,則改為回傳一個(output, vjp_fn, aux)
元組。回傳的vjp_fn
函數將回傳一個包含每個 VJP 的元組。
在簡單的情況下使用時,
vjp()
的行為與grad()
相同。>>> x = torch.randn([5]) >>> f = lambda x: x.sin().sum() >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> grad = vjpfunc(torch.tensor(1.))[0] >>> assert torch.allclose(grad, torch.func.grad(f)(x))
但是,
vjp()
可以透過傳入每個輸出的 cotangents 來支援具有多個輸出的函數。>>> x = torch.randn([5]) >>> f = lambda x: (x.sin(), x.cos()) >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp()
甚至可以支援輸出為 Python 結構的情況。>>> x = torch.randn([5]) >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} >>> vjps = vjpfunc(cotangents) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp()
回傳的函數將會計算相對於每個primals
的偏導數。>>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y) >>> cotangents = torch.randn([5, 5]) >>> vjps = vjpfunc(cotangents) >>> assert len(vjps) == 2 >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
primals
是f
的位置參數。所有 kwargs 都使用它們的預設值。>>> x = torch.randn([5]) >>> def f(x, scale=4.): >>> return x * scale >>> >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc(torch.ones_like(x)) >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
注意
將 PyTorch
torch.no_grad
與vjp
一起使用。情況 1:在函數內部使用torch.no_grad
。>>> def f(x): >>> with torch.no_grad(): >>> c = x ** 2 >>> return x - c
在這種情況下,
vjp(f)(x)
將尊重內部的torch.no_grad
。情況 2:在
torch.no_grad
上下文管理器內使用vjp
。>>> with torch.no_grad(): >>> vjp(f)(x)
在這種情況下,
vjp
將尊重內部的torch.no_grad
,但不尊重外部的。這是因為vjp
是一個「函數轉換」:其結果不應取決於f
外部的上下文管理器的結果。