神經正切核¶
神經正切核 (NTK) 是一種描述 神經網路在訓練過程中如何演變 的核。 近年來 有很多關於它的研究。本教學課程受到 JAX 中 NTK 的實現 的啟發(詳情請參閱 快速有限寬度神經正切核),演示了如何使用 functorch 輕鬆計算此數量。
設定¶
首先,進行一些設定。讓我們定義一個我們想要計算其 NTK 的簡單 CNN。
import torch
import torch.nn as nn
from functorch import make_functional, vmap, vjp, jvp, jacrev
device = 'cuda'
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, (3, 3))
self.conv2 = nn.Conv2d(32, 32, (3, 3))
self.conv3 = nn.Conv2d(32, 32, (3, 3))
self.fc = nn.Linear(21632, 10)
def forward(self, x):
x = self.conv1(x)
x = x.relu()
x = self.conv2(x)
x = x.relu()
x = self.conv3(x)
x = x.flatten(1)
x = self.fc(x)
return x
讓我們產生一些隨機資料
x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)
建立模型的函數版本¶
functorch 變換對函數進行操作。特別是,為了計算 NTK,我們需要一個函數,它接受模型的參數和單個輸入(而不是一批輸入!),並返回單個輸出。
我們將使用 functorch 的 make_functional
來完成第一步。如果您的模組具有緩衝區,則您需要改用 make_functional_with_buffers
。
net = CNN().to(device)
fnet, params = make_functional(net)
請記住,模型最初是為了接受一批輸入資料點而編寫的。在我們的 CNN 示例中,沒有批次間操作。也就是說,批次中的每個資料點都獨立於其他資料點。考慮到這個假設,我們可以輕鬆地生成一個函數,用於評估單個資料點上的模型
def fnet_single(params, x):
return fnet(params, x.unsqueeze(0)).squeeze(0)
計算 NTK:方法 1(雅可比矩陣收縮)¶
我們準備好計算經驗 NTK。兩個資料點 \(x_1\) 和 \(x_2\) 的經驗 NTK 定義為在 \(x_1\) 處評估的模型的雅可比矩陣與在 \(x_2\) 處評估的模型的雅可比矩陣之間的矩陣乘積
在批次情況下,其中 \(x_1\) 是一批資料點,而 \(x_2\) 是一批資料點,那麼我們想要的是來自 \(x_1\) 的資料點與來自 \(x_2\) 的資料點的所有組合的雅可比矩陣之間的矩陣乘積。
第一種方法恰恰是這樣做的 - 計算兩個雅可比矩陣,並將它們收縮。以下是計算批次情況下 NTK 的方法
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
print(result.shape)
torch.Size([20, 5, 10, 10])
在某些情況下,您可能只想獲得此數量的對角線或跡線,尤其是在您事先知道網路架構會導致非對角線元素可以近似為零的 NTK 時。很容易調整上述函數來做到這一點
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
einsum_expr = None
if compute == 'full':
einsum_expr = 'Naf,Mbf->NMab'
elif compute == 'trace':
einsum_expr = 'Naf,Maf->NM'
elif compute == 'diagonal':
einsum_expr = 'Naf,Maf->NMa'
else:
assert False
result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)
torch.Size([20, 5])
此方法的漸近時間複雜度為 \(N O [FP]\)(計算雅可比矩陣的時間) \( + N^2 O^2 P\)(收縮雅可比矩陣的時間),其中 \(N\) 是 \(x_1\) 和 \(x_2\) 的批次大小,\(O\) 是模型的輸出大小,\(P\) 是參數總數,而 \([FP]\) 是模型單次正向傳遞的成本。詳情請參閱 快速有限寬度神經正切核 中的第 3.2 節。
計算 NTK:方法 2(NTK-向量積)¶
我們將討論的下一個方法是一種使用 NTK-向量積來計算 NTK 的方法。
此方法將 NTK 重新表述為應用於大小為 \(O\times O\) 的單位矩陣 \(I_O\) 的列的 NTK-向量積的堆疊(其中 \(O\) 是模型的輸出大小)
其中 \(e_o\in \mathbb{R}^O\) 是單位矩陣 \(I_O\) 的列向量。
令 \(\textrm{vjp}_o = J_{net}^T(x_2) e_o\)。我們可以使用向量-雅可比矩陣積來計算這一點。
現在,考慮 \(J_{net}(x_1) \textrm{vjp}_o\)。這是一個雅可比矩陣-向量積!
最後,我們可以使用
vmap
在 \(I_O\) 的所有列 \(e_o\) 上並行運行上述計算。
這表明我們可以使用反向模式 AD(用於計算向量-雅可比矩陣積)和正向模式 AD(用於計算雅可比矩陣-向量積)的組合來計算 NTK。
讓我們編寫程式碼
def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
def get_ntk(x1, x2):
def func_x1(params):
return func(params, x1)
def func_x2(params):
return func(params, x2)
output, vjp_fn = vjp(func_x1, params)
def get_ntk_slice(vec):
# This computes vec @ J(x2).T
# `vec` is some unit vector (a single slice of the Identity matrix)
vjps = vjp_fn(vec)
# This computes J(X1) @ vjps
_, jvps = jvp(func_x2, (params,), vjps)
return jvps
# Here's our identity matrix
basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
return vmap(get_ntk_slice)(basis)
# get_ntk(x1, x2) computes the NTK for a single data point x1, x2
# Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,
# we actually wish to compute the NTK between every pair of data points
# between {x1} and {x2}. That's what the vmaps here do.
result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)
if compute == 'full':
return result
if compute == 'trace':
return torch.einsum('NMKK->NM', result)
if compute == 'diagonal':
return torch.einsum('NMKK->NMK', result)
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
我們用於 empirical_ntk_ntk_vps
的程式碼看起來像是對上述數學的直接翻譯!這展示了函數變換的強大功能:祝您好運,嘗試使用普通的 PyTorch 編寫上述程式碼的有效版本。
此方法的漸近時間複雜度為 \(N^2 O [FP]\),其中 \(N\) 是 \(x_1\) 和 \(x_2\) 的批次大小,\(O\) 是模型的輸出大小,而 \([FP]\) 是模型單次正向傳遞的成本。因此,與方法 1(雅可比矩陣收縮)相比,此方法在網路中執行的正向傳遞次數更多(\(N^2 O\) 而不是 \(N O\)),但完全避免了收縮成本(沒有 \(N^2 O^2 P\) 項,其中 \(P\) 是模型參數的總數)。因此,當 \(O P\) 相對於 \([FP]\) 較大時,例如具有許多輸出 \(O\) 的全連接(非卷積)模型,此方法更佳。在記憶體方面,這兩種方法應該是相當的。詳情請參閱 快速有限寬度神經正切核 中的第 3.3 節。