雅可比矩陣、海森矩陣、hvp、vhp 等:組合 functorch 變換¶
計算雅可比矩陣或海森矩陣在許多非傳統的深度學習模型中都很有用。使用標準的自動微分系統(如 PyTorch Autograd)很難(或很煩人)有效地計算這些量;functorch 提供了有效計算各種高階自動微分量的方法。
計算雅可比矩陣¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)
讓我們從一個我們想要計算其雅可比矩陣的函數開始。這是一個帶有非線性激活函數的簡單線性函數。
def predict(weight, bias, x):
return F.linear(x, weight, bias).tanh()
讓我們添加一些虛擬數據:權重、偏差和特徵向量 x。
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector
讓我們將 predict
視為將輸入 x
從 \(R^D -> R^D\) 映射的函數。PyTorch Autograd 計算向量-雅可比矩陣乘積。為了計算這個 \(R^D -> R^D\) 函數的完整雅可比矩陣,我們必須每次使用不同的單位向量逐行計算它。
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
return torch.stack(jacobian_rows)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
jacobian = compute_jac(xp)
print(jacobian.shape)
print(jacobian[0]) # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,
0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
我們可以使用 vmap 來擺脫 for 迴圈並向量化計算,而不是逐行計算雅可比矩陣。我們不能直接將 vmap 應用於 PyTorch Autograd;相反,functorch 提供了一個 vjp 變換
from functorch import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
# lets confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
在未來的教學中,反向模式自動微分和 vmap 的組合將為我們提供每個樣本的梯度。在本教學中,反向模式自動微分和 vmap 的組合為我們提供了雅可比矩陣計算!vmap 和自動微分變換的各種組合可以為我們提供不同的有趣量。
functorch 提供了 jacrev 作為一個方便的函數,它執行 vmap-vjp 組合來計算雅可比矩陣。jacrev 接受一個 argnums 參數,該參數說明我們要計算哪個參數的雅可比矩陣。
from functorch import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# confirm
assert torch.allclose(ft_jacobian, jacobian)
讓我們比較兩種計算雅可比矩陣方法的效能。functorch 版本的速度要快得多(而且輸出越多速度就越快)。
一般來說,我們預計通過 vmap 進行向量化可以幫助消除開銷並更好地利用您的硬體。
Vmap 通過將外部迴圈推送到函數的原始操作中來實現這種魔法,以便獲得更好的效能。
讓我們建立一個快速函數來評估效能並處理微秒和毫秒的測量值
def get_perf(first, first_descriptor, second, second_descriptor):
""" takes torch.benchmark objects and compares delta of second vs first. """
faster = second.times[0]
slower = first.times[0]
gain = (slower-faster)/slower
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然後使用我們的 get_perf 函數執行效能比較
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)
print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a911b350>
compute_jac(xp)
2.25 ms
1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a6a99d50>
jacrev(predict, argnums=2)(weight, bias, x)
884.34 us
1 measurement, 500 runs , 1 thread
讓我們使用我們的 get_perf 函數對上述內容進行相對效能比較
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap");
Performance delta: 60.7170 percent improvement with vmap
此外,將問題反過來並說我們想要計算模型參數(權重、偏差)的雅可比矩陣,而不是輸入,這也很容易。
# note the change in input via argnums params of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式雅可比矩陣 (jacrev) vs 正向模式雅可比矩陣 (jacfwd)¶
我們提供了兩個計算雅可比矩陣的 API:jacrev 和 jacfwd
jacrev 使用反向模式自動微分。正如您在上面看到的,它是我們的 vjp 和 vmap 變換的組合。
jacfwd 使用正向模式自動微分。它被實現為我們的 jvp 和 vmap 變換的組合。
jacfwd 和 jacrev 可以互相替換,但它們具有不同的效能特徵。
作為一般經驗法則,如果您要計算 \(𝑅^N \to R^M\) 函數的雅可比矩陣,並且輸出遠遠多於輸入(即 \(M > N\)),則首選 jacfwd,否則使用 jacrev。此規則也有例外,但以下是對此規則的非嚴格論證
在反向模式自動微分中,我們逐行計算雅可比矩陣,而在正向模式自動微分(計算雅可比矩陣-向量乘積)中,我們逐列計算它。雅可比矩陣有 M 行 N 列,所以如果它在某個方向上更高或更寬,我們可能更喜歡處理行數或列數更少的方法。
from functorch import jacrev, jacfwd
首先,讓我們用輸入多於輸出的情況進行基準測試
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
# remember the general rule about taller vs wider...here we have a taller matrix:
print(weight.shape)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d792d0>
jacfwd(predict, argnums=2)(weight, bias, x)
1.32 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a4dee450>
jacrev(predict, argnums=2)(weight, bias, x)
12.46 ms
1 measurement, 500 runs , 1 thread
然後進行相對基準測試
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 842.8274 percent improvement with jacrev
現在反過來 - 輸出 (M) 多於輸入 (N)
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d64790>
jacfwd(predict, argnums=2)(weight, bias, x)
7.99 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d67b50>
jacrev(predict, argnums=2)(weight, bias, x)
1.09 ms
1 measurement, 500 runs , 1 thread
和相對效能比較
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 635.2095 percent improvement with jacfwd
使用 functorch.hessian 計算海森矩陣¶
我們提供了一個方便的 API 來計算海森矩陣:functorch.hessian
。海森矩陣是雅可比矩陣的雅可比矩陣(或偏導數的偏導數,又稱二階導數)。
這表明可以組合 functorch 的雅可比矩陣變換來計算海森矩陣。事實上,在底層,hessian(f)
就是 jacfwd(jacrev(f))
。
注意:為了提高效能:根據您的模型,您可能也想使用 jacfwd(jacfwd(f))
或 jacrev(jacrev(f))
來計算海森矩陣,利用上面關於更寬 vs 更高的矩陣的經驗法則。
from functorch import hessian
# lets reduce the size in order not to blow out colab. Hessians require significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
讓我們驗證無論是使用 hessian api 還是使用 jacfwd(jacfwd()),我們都得到了相同的結果
torch.allclose(hess_api, hess_fwdfwd)
True
批次雅可比矩陣和批次海森矩陣¶
在上面的例子中,我們一直在使用單個特徵向量。在某些情況下,您可能想要計算一批輸出相對於一批輸入的雅可比矩陣。也就是說,給定形狀為 (B, N)
的一批輸入和一個從 \(R^N \to R^M\) 的函數,我們想要一個形狀為 (B, M, N)
的雅可比矩陣。
最簡單的方法是使用 vmap
batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)
weight shape = torch.Size([33, 31])
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
如果您有一個從 (B, N) -> (B, M) 的函數,並且確定每個輸入都產生一個獨立的輸出,那麼有時也可以通過對輸出求和然後計算該函數的雅可比矩陣來做到這一點,而無需使用 vmap
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果您有一個從 \(𝑅^𝑁 \to 𝑅^𝑀\) 的函數,但輸入是批次的,您可以將 vmap 與 jacrev 組合起來計算批次雅可比矩陣
最後,批次海森矩陣的計算方法類似。最容易想到的是使用 vmap 對海森矩陣計算進行批次處理,但在某些情況下,求和技巧也適用。
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])
計算海森矩陣-向量乘積¶
計算海森矩陣-向量乘積 (hvp) 的天真方法是實例化完整的海森矩陣並與向量執行點積。我們可以做得更好:事實證明我們不需要實例化完整的海森矩陣來做到這一點。我們將介紹兩種(許多種中的兩種)不同的策略來計算海森矩陣-向量乘積
將反向模式自動微分與反向模式自動微分組合
將反向模式自動微分與正向模式自動微分組合
與反向模式與反向模式相比,將反向模式自動微分與正向模式自動微分組合通常是計算 hvp 的更節省記憶體的方法,因為正向模式自動微分不需要構建 Autograd 圖並保存中間結果以供反向傳播
from functorch import jvp, grad, vjp
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
以下是一些用法示例。
def f(x):
return x.sin().sum()
x = torch.randn(2048)
tangent = torch.randn(2048)
result = hvp(f, (x,), (tangent,))
如果 PyTorch 正向自動微分沒有覆蓋您的操作,那麼我們可以將反向模式自動微分與反向模式自動微分組合
def hvp_revrev(f, primals, tangents):
_, vjp_fn = vjp(grad(f), *primals)
return vjp_fn(*tangents)
result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
assert torch.allclose(result, result_hvp_revrev[0])