注意
點擊這裡下載完整範例程式碼
Jacobians, Hessians, hvp, vhp 及更多:組合函數轉換¶
建立於:2023 年 3 月 15 日 | 最後更新:2023 年 4 月 18 日 | 最後驗證:2024 年 11 月 05 日
計算 Jacobians 或 Hessians 在許多非傳統的深度學習模型中很有用。使用 PyTorch 常規的自動微分 API(Tensor.backward()
, torch.autograd.grad
)有效地計算這些量是很困難(或令人厭煩)的。PyTorch 受 JAX 啟發的 函數轉換 API 提供了有效計算各種高階自動微分量的方法。
注意
本教學需要 PyTorch 2.0.0 或更新版本。
計算 Jacobian¶
import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)
讓我們從一個想要計算 Jacobian 的函數開始。這是一個具有非線性激活的簡單線性函數。
讓我們添加一些虛擬資料:權重、偏差和特徵向量 x。
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector
讓我們將 predict
視為將輸入 x
從 \(R^D \to R^D\) 映射的函數。PyTorch Autograd 計算向量-Jacobian 乘積。為了計算這個 \(R^D \to R^D\) 函數的完整 Jacobian,我們必須逐行計算,每次使用不同的單位向量。
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
return torch.stack(jacobian_rows)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
jacobian = compute_jac(xp)
print(jacobian.shape)
print(jacobian[0]) # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,
0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
我們可以不用逐行計算 Jacobian,而是使用 PyTorch 的 torch.vmap
函數轉換來消除 for 迴圈並向量化計算。我們無法直接將 vmap
應用於 torch.autograd.grad
;相反,PyTorch 提供了一個與 torch.vmap
組成的 torch.func.vjp
轉換
from torch.func import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
在後面的教學中,反向模式 AD 和 vmap
的組合將為我們提供每個樣本的梯度。在本教學中,組合反向模式 AD 和 vmap
將為我們提供 Jacobian 計算! vmap
和自動微分轉換的各種組合可以為我們提供不同的有趣量。
PyTorch 提供了 torch.func.jacrev
作為執行 vmap-vjp
組合以計算 Jacobians 的便利函數。 jacrev
接受一個 argnums
參數,該參數說明我們希望計算哪個參數的 Jacobians。
from torch.func import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)
讓我們比較一下計算 Jacobian 的兩種方法的效能。函數轉換版本快得多(並且輸出越多,速度越快)。
通常,我們期望通過 vmap
進行向量化可以幫助消除開銷並更好地利用您的硬體。
vmap
通過將外部迴圈向下推入函數的原始操作中來完成此魔術,以獲得更好的效能。
讓我們快速建立一個函數來評估效能並處理微秒和毫秒的測量
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
faster = second.times[0]
slower = first.times[0]
gain = (slower-faster)/slower
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然後執行效能比較
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)
print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd07017b370>
compute_jac(xp)
3.06 ms
1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd05ca88340>
jacrev(predict, argnums=2)(weight, bias, x)
772.86 us
1 measurement, 500 runs , 1 thread
讓我們使用我們的 get_perf
函數對以上內容進行相對效能比較
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")
Performance delta: 74.7730 percent improvement with vmap
此外,很容易翻轉問題並說我們想要計算模型參數(權重、偏差)而不是輸入的 Jacobians
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式 Jacobian (jacrev
) vs 前向模式 Jacobian (jacfwd
)¶
我們提供了兩個 API 來計算 Jacobians:jacrev
和 jacfwd
jacrev
使用反向模式 AD。正如您在上面看到的,它是我們的vjp
和vmap
轉換的組合。jacfwd
使用前向模式 AD。它實作方式是透過組合我們的jvp
和vmap
轉換。
jacfwd
和 jacrev
可以互相替換,但它們有不同的效能特性。
一般來說,如果您要計算 \(R^N \to R^M\) 函數的 Jacobian 矩陣,且輸出的數量遠大於輸入 (例如,\(M > N\)),那麼會優先選擇 jacfwd
,否則使用 jacrev
。 這個規則也有例外情況,以下提供一個不嚴謹的論證:
在反向模式 AD 中,我們逐列計算 Jacobian 矩陣,而在前向模式 AD 中 (計算 Jacobian 向量積),我們逐欄計算。 Jacobian 矩陣有 M 列和 N 欄,所以如果矩陣在某個方向上更高或更寬,我們可能會優先選擇處理較少列或欄的方法。
首先,我們用比輸出更多的輸入來進行基準測試
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fd07067c970>
jacfwd(predict, argnums=2)(weight, bias, x)
1.39 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fd05c9e64a0>
jacrev(predict, argnums=2)(weight, bias, x)
7.97 ms
1 measurement, 500 runs , 1 thread
然後進行相對基準測試
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 472.0687 percent improvement with jacrev
現在反過來 - 更多的輸出 (M) 比輸入 (N) 多
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fd05ca73cd0>
jacfwd(predict, argnums=2)(weight, bias, x)
6.60 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fd0a464fcd0>
jacrev(predict, argnums=2)(weight, bias, x)
882.64 us
1 measurement, 500 runs , 1 thread
以及相對效能比較
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 647.2890 percent improvement with jacfwd
使用 functorch.hessian 計算 Hessian 矩陣¶
我們提供了一個方便的 API 來計算 Hessian 矩陣:torch.func.hessian
。 Hessian 矩陣是 Jacobian 矩陣的 Jacobian 矩陣 (或偏導數的偏導數,也稱為二階導數)。
這表明可以直接組合 functorch Jacobian 轉換來計算 Hessian 矩陣。 事實上,在底層,hessian(f)
只是 jacfwd(jacrev(f))
。
注意:為了提高效能:根據您的模型,您可能還想使用 jacfwd(jacfwd(f))
或 jacrev(jacrev(f))
來計算 Hessian 矩陣,並利用上述關於較寬矩陣與較高矩陣的經驗法則。
from torch.func import hessian
# lets reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
讓我們驗證無論是使用 Hessian API 還是使用 jacfwd(jacfwd())
,我們都能得到相同的結果。
True
批次 Jacobian 矩陣和批次 Hessian 矩陣¶
在上面的例子中,我們一直使用單一的特徵向量。 在某些情況下,您可能想要計算一批輸出的 Jacobian 矩陣,這些輸出相對於一批輸入。 也就是說,給定形狀為 (B, N)
的一批輸入和一個從 \(R^N \to R^M\) 的函數,我們希望 Jacobian 矩陣的形狀為 (B, M, N)
。
最簡單的方法是使用 vmap
batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])
如果您有一個從 (B, N) -> (B, M) 的函數,並且確定每個輸入都會產生獨立的輸出,那麼有時也可以在不使用 vmap
的情況下,將輸出相加,然後計算該函數的 Jacobian 矩陣。
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果您有一個從 \(R^N \to R^M\) 的函數,但輸入是批次的,您可以將 vmap
與 jacrev
組合起來以計算批次的 Jacobian 矩陣。
最後,可以類似地計算批次 Hessian 矩陣。 最容易的方法是使用 vmap
來對 Hessian 矩陣計算進行批次處理,但在某些情況下,求和技巧也有效。
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])
計算 Hessian 向量積¶
計算 Hessian 向量積 (hvp) 的最簡單方法是將完整的 Hessian 矩陣實體化,並執行與向量的點積。 我們可以做得更好:事實證明,我們不需要將完整的 Hessian 矩陣實體化才能做到這一點。 我們將介紹兩種 (許多) 不同的策略來計算 Hessian 向量積: - 將反向模式 AD 與反向模式 AD 組合 - 將反向模式 AD 與前向模式 AD 組合
將反向模式 AD 與前向模式 AD 組合 (而不是反向模式與反向模式) 通常是計算 hvp 更節省記憶體的方式,因為前向模式 AD 不需要建構 Autograd 圖形並儲存中介值以進行反向傳播
以下是一些範例用法。
def f(x):
return x.sin().sum()
x = torch.randn(2048)
tangent = torch.randn(2048)
result = hvp(f, (x,), (tangent,))
如果 PyTorch 前向 AD 沒有涵蓋您的操作,那麼我們可以將反向模式 AD 與反向模式 AD 組合
腳本的總執行時間: ( 0 分鐘 11.637 秒)