捷徑

神經正切核

建立時間:2023 年 3 月 15 日 | 最後更新時間:2023 年 6 月 16 日 | 最後驗證時間:未驗證

神經正切核 (NTK) 是一個描述神經網路在訓練期間如何演變的核。 近年來,對其進行了大量研究。 本教學受到JAX 中 NTK 的實作的啟發 (詳情請參閱快速有限寬度神經正切核),示範如何使用 torch.func (PyTorch 的可組合函數轉換) 輕鬆計算此量。

注意

本教學需要 PyTorch 2.0.0 或更高版本。

設定

首先,進行一些設定。 讓我們定義一個簡單的 CNN,我們希望計算其 NTK。

import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, (3, 3))
        self.conv2 = nn.Conv2d(32, 32, (3, 3))
        self.conv3 = nn.Conv2d(32, 32, (3, 3))
        self.fc = nn.Linear(21632, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = x.relu()
        x = self.conv2(x)
        x = x.relu()
        x = self.conv3(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x

讓我們產生一些隨機資料

x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)

建立模型的功能版本

torch.func 轉換對函數進行操作。 特別是,為了計算 NTK,我們需要一個接受模型參數和單一輸入 (而不是一批輸入!) 並傳回單一輸出的函數。

我們將使用 torch.func.functional_call,這允許我們使用不同的參數/緩衝區呼叫 nn.Module,以協助完成第一步。

請記住,該模型最初是編寫為接受一批輸入資料點。 在我們的 CNN 範例中,沒有批次間操作。 也就是說,批次中的每個資料點都獨立於其他資料點。 考慮到這個假設,我們可以輕鬆地產生一個在單一資料點上評估模型的功能

net = CNN().to(device)

# Detaching the parameters because we won't be calling Tensor.backward().
params = {k: v.detach() for k, v in net.named_parameters()}

def fnet_single(params, x):
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

計算 NTK:方法 1 (Jacobian 收縮)

我們已準備好計算經驗 NTK。 兩個資料點 \(x_1\)\(x_2\) 的經驗 NTK 定義為在 \(x_1\) 處評估的模型 Jacobian 與在 \(x_2\) 處評估的模型 Jacobian 之間的矩陣乘積

\[J_{net}(x_1) J_{net}^T(x_2)\]

在批次案例中,其中 \(x_1\) 是一批資料點,而 \(x_2\) 是一批資料點,那麼我們想要 \(x_1\)\(x_2\) 中所有資料點組合的 Jacobian 之間的矩陣乘積。

第一種方法包括執行此操作 - 計算兩個 Jacobian,並收縮它們。 以下是如何在批次案例中計算 NTK

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = jac1.values()
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = jac2.values()
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result

result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
print(result.shape)
torch.Size([20, 5, 10, 10])

在某些情況下,您可能只想知道此量的對角線或追蹤,特別是如果您事先知道網路架構會產生 NTK,其中非對角線元素可以用零近似。 很容易調整上述函數來做到這一點

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = jac1.values()
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = jac2.values()
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    einsum_expr = None
    if compute == 'full':
        einsum_expr = 'Naf,Mbf->NMab'
    elif compute == 'trace':
        einsum_expr = 'Naf,Maf->NM'
    elif compute == 'diagonal':
        einsum_expr = 'Naf,Maf->NMa'
    else:
        assert False

    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result

result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)
torch.Size([20, 5])

此方法的漸近時間複雜度為 \(N O [FP]\) (計算 Jacobian 的時間) + \(N^2 O^2 P\) (收縮 Jacobian 的時間),其中 \(N\)\(x_1\)\(x_2\) 的批次大小,\(O\) 是模型的輸出大小,\(P\) 是參數總數,而 \([FP]\) 是單次正向傳遞模型的成本。 詳情請參閱 快速有限寬度神經正切核的 3.2 節。

計算 NTK:方法 2 (NTK-向量乘積)

我們將討論的下一個方法是使用 NTK-向量乘積計算 NTK 的方法。

此方法將 NTK 重新公式化為應用於大小為 \(O\times O\) 的單位矩陣 \(I_O\) (其中 \(O\) 是模型的輸出大小) 的列的 NTK-向量乘積堆疊

\[J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \left[J_{net}(x_1) \left[J_{net}^T(x_2) e_o\right]\right]_{o=1}^{O},\]

其中 \(e_o\in \mathbb{R}^O\) 是單位矩陣 \(I_O\) 的列向量。

  • \(\textrm{vjp}_o = J_{net}^T(x_2) e_o\)。 我們可以使用向量-Jacobian 乘積來計算此值。

  • 現在,考慮 \(J_{net}(x_1) \textrm{vjp}_o\)。 這是 Jacobian-向量乘積!

  • 最後,我們可以使用 vmap 並行地在 \(I_O\) 的所有列 \(e_o\) 上執行上述計算。

這表示我們可以使用反向模式 AD (用於計算向量-Jacobian 乘積) 和正向模式 AD (用於計算 Jacobian-向量乘積) 的組合來計算 NTK。

讓我們編寫程式碼

def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        output, vjp_fn = vjp(func_x1, params)

        def get_ntk_slice(vec):
            # This computes ``vec @ J(x2).T``
            # `vec` is some unit vector (a single slice of the Identity matrix)
            vjps = vjp_fn(vec)
            # This computes ``J(X1) @ vjps``
            _, jvps = jvp(func_x2, (params,), vjps)
            return jvps

        # Here's our identity matrix
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
        return vmap(get_ntk_slice)(basis)

    # ``get_ntk(x1, x2)`` computes the NTK for a single data point x1, x2
    # Since the x1, x2 inputs to ``empirical_ntk_ntk_vps`` are batched,
    # we actually wish to compute the NTK between every pair of data points
    # between {x1} and {x2}. That's what the ``vmaps`` here do.
    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK', result)

# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
with torch.backends.cudnn.flags(allow_tf32=False):
    result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
    result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)

assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)

我們的 empirical_ntk_ntk_vps 程式碼看起來像是從上面的數學直接翻譯! 這展示了函數轉換的力量:祝您好運,嘗試僅使用 torch.autograd.grad 來編寫上述程式碼的有效版本。

此方法的漸近時間複雜度為 \(N^2 O [FP]\),其中 \(N\)\(x_1\)\(x_2\) 的批次大小,\(O\) 是模型的輸出大小,而 \([FP]\) 是單次模型前向傳遞的成本。因此,此方法比方法 1 Jacobian 收縮(\(N^2 O\) 而非 \(N O\))執行更多次網路前向傳遞,但完全避免了收縮成本(沒有 \(N^2 O^2 P\) 項,其中 \(P\) 是模型的參數總數)。因此,當 \(O P\) 相對於 \([FP]\) 較大時,例如具有多個輸出 \(O\) 的全連接(非卷積)模型,此方法更為可取。在記憶體方面,兩種方法應該相當。詳情請參閱快速有限寬度神經正切核 (Fast Finite Width Neural Tangent Kernel) 的第 3.3 節。

腳本總運行時間:( 0 分鐘 0.541 秒)

圖庫由 Sphinx-Gallery 產生


評價本教學

© 版權所有 2024, PyTorch。

使用 Sphinx 構建,主題由 theme 提供,由 Read the Docs 提供。

文件

獲取 PyTorch 的全面開發人員文檔

查看文檔

教學

獲取針對初學者和高級開發人員的深入教程

查看教程

資源

尋找開發資源並獲得您問題的解答

查看資源