快速鍵

functorch.vjp

functorch.vjp(func, *primals, has_aux=False)[source]

向量雅可比積的簡寫,傳回包含下列結果的元組:將 func 應用到 primals 的結果,以及一個函數,該函數在提供 cotangents 時,計算 func 對於 primals 的反向模式雅可比,再乘以 cotangents

參數
  • func (Callable) – 一個接收一或多個引數的 Python 函數。必須傳回一個或多個張量。

  • primals (Tensors) – 傳遞給 func 的位置引數,所有內容都必須是張量。傳回的函數也會計算關於這些引數的導數

  • has_aux (bool) – 旗標指示 func 傳回一個 (output, aux) 元組,其中第一個元素是函數輸出,準備用於區分,而第二個元素則是不會區分的其他輔助物件。預設值:False.

傳回

傳回一個 (output, vjp_fn) 元組,其中包含將 func 應用到 primals 的輸出,以及一個函數,該函數計算 func 對於所有 primals 的 VJP,使用傳遞給傳回函數的餘切。如果 has_aux is True,則傳回 (output, vjp_fn, aux) 元組。傳回的 vjp_fn 函數將傳回一個每個 VJP 的元組。

在簡單案例中使用時,vjp() 的行為與 grad() 相同

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))

然而,vjp() 可以支援有多個輸出的函數,方法是傳入每個輸出的餘切

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() 甚至可以支援輸出為 Python 結構

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() 傳回的函數將計算關於每個 primals 的偏微分

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

primalsf 的位置引數。所有關鍵字引數都會使用其預設值

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>   return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

備註

搭配 vjp 使用 PyTorch torch.no_grad。案例 1:在函數內使用 torch.no_grad

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

在本例中,vjp(f)(x) 將會尊重內部的 torch.no_grad

案例 2:在 torch.no_grad 語法管理員中使用 vjp

>>> # xdoctest: +SKIP(failing)
>>> with torch.no_grad():
>>>     vjp(f)(x)

在本例中,vjp 將會尊重內部的 torch.no_grad,但不會尊重外部的。這是因為 vjp 是「函數轉換」:它的結果不應依賴於外部 f 的語法管理員的結果。

警告

我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,functorch.vjp 自 PyTorch 2.0 版開始已標示為不建議使用,且會在 PyTorch >= 2.3 的未來版本中刪除。請改用 torch.func.vjp;有關更多詳細資訊,請參閱 PyTorch 2.0 發行說明和/或 torch.func 移轉指南 https://pytorch.dev.org.tw/docs/master/func.migrating.html

文件說明

取得 PyTorch 的全面開發人員文件

查看文件

教學課程

取得初學者和進階開發人員的深入教學

查看教學

資源

找到開發資源,並找出您的疑問解答

查看資源