捷徑

torch.func.vjp

torch.func.vjp(func, *primals, has_aux=False)[原始碼]

代表向量-雅可比乘積 (vector-Jacobian product),回傳一個包含以下結果的元組:將 func 應用於 primals 的結果,以及一個函數。當給定 cotangents 時,該函數計算 func 相對於 primals 的反向模式雅可比矩陣乘以 cotangents 的值。

參數
  • func (Callable) – 一個接受一或多個參數的 Python 函數。必須回傳一或多個 Tensor。

  • primals (Tensors) – 傳遞給 func 的位置參數,這些參數都必須是 Tensor。回傳的函數也會計算相對於這些參數的導數。

  • has_aux (bool) – 一個旗標,指示 func 回傳一個 (output, aux) 元組,其中第一個元素是要微分的函數的輸出,第二個元素是不會被微分的其他輔助物件。預設值:False。

回傳

回傳一個 (output, vjp_fn) 元組,包含將 func 應用於 primals 的輸出,以及一個計算 func 相對於所有 primals 的 vjp 的函數,計算時會使用傳遞給回傳函數的 cotangents。如果 has_aux is True,則改為回傳一個 (output, vjp_fn, aux) 元組。回傳的 vjp_fn 函數將回傳一個包含每個 VJP 的元組。

在簡單的情況下使用時,vjp() 的行為與 grad() 相同。

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))

但是,vjp() 可以透過傳入每個輸出的 cotangents 來支援具有多個輸出的函數。

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() 甚至可以支援輸出為 Python 結構的情況。

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() 回傳的函數將會計算相對於每個 primals 的偏導數。

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

primalsf 的位置參數。所有 kwargs 都使用它們的預設值。

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>   return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

注意

將 PyTorch torch.no_gradvjp 一起使用。情況 1:在函數內部使用 torch.no_grad

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

在這種情況下,vjp(f)(x) 將尊重內部的 torch.no_grad

情況 2:在 torch.no_grad 上下文管理器內使用 vjp

>>> with torch.no_grad():
>>>     vjp(f)(x)

在這種情況下,vjp 將尊重內部的 torch.no_grad,但不尊重外部的。這是因為 vjp 是一個「函數轉換」:其結果不應取決於 f 外部的上下文管理器的結果。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學課程

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源