捷徑

逐樣本梯度

建立於:2023 年 3 月 15 日 | 最後更新:2024 年 4 月 24 日 | 最後驗證:2024 年 11 月 05 日

這是什麼?

逐樣本梯度計算是計算資料批次中每個樣本的梯度。它在差分隱私、元學習和優化研究中是一個有用的量。

注意

本教程需要 PyTorch 2.0.0 或更高版本。

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple CNN and loss function:

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

def loss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)

讓我們產生一批虛擬資料,並假裝我們正在處理 MNIST 資料集。虛擬圖像為 28 x 28,我們使用大小為 64 的小批量。

device = 'cuda'

num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)

targets = torch.randint(10, (64,), device=device)

在常規模型訓練中,會將小批量傳遞給模型,然後呼叫 .backward() 來計算梯度。這將產生整個小批次的「平均」梯度

model = SimpleCNN().to(device=device)
predictions = model(data)  # move the entire mini-batch through the model

loss = loss_fn(predictions, targets)
loss.backward()  # back propagate the 'average' gradient of this mini-batch

與上述方法相比,逐樣本梯度計算相當於

  • 對於資料的每個單獨樣本,執行正向傳播和反向傳播以獲得單獨的(逐樣本)梯度。

def compute_grad(sample, target):
    sample = sample.unsqueeze(0)  # prepend batch dimension for processing
    target = target.unsqueeze(0)

    prediction = model(sample)
    loss = loss_fn(prediction, target)

    return torch.autograd.grad(loss, list(model.parameters()))


def compute_sample_grads(data, targets):
    """ manually process each sample with per sample gradient """
    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
    sample_grads = zip(*sample_grads)
    sample_grads = [torch.stack(shards) for shards in sample_grads]
    return sample_grads

per_sample_grads = compute_sample_grads(data, targets)

sample_grads[0] 是 model.conv1.weight 的逐樣本梯度。model.conv1.weight.shape[32, 1, 3, 3];請注意,批次中每個樣本都有一個梯度,總共 64 個。

print(per_sample_grads[0].shape)
torch.Size([64, 32, 1, 3, 3])

逐樣本梯度,使用函數轉換的有效方式

我們可以透過使用函數轉換有效地計算逐樣本梯度。

torch.func 函數轉換 API 會轉換函數。我們的策略是定義一個計算損失的函數,然後應用轉換來建構一個計算逐樣本梯度的函數。

我們將使用 torch.func.functional_call 函數將 nn.Module 視為函數。

首先,讓我們從 model 中提取狀態到兩個字典中,參數和緩衝區。我們將分離它們,因為我們不會使用常規的 PyTorch autograd(例如 Tensor.backward()、torch.autograd.grad)。

from torch.func import functional_call, vmap, grad

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

接下來,讓我們定義一個函數,用於計算給定單個輸入而不是批量輸入的模型的損失。重要的是,此函數接受參數、輸入和目標,因為我們將轉換它們。

注意 - 由於模型最初是為處理批次而編寫的,我們將使用 torch.unsqueeze 來新增批次維度。

def compute_loss(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = functional_call(model, (params, buffers), (batch,))
    loss = loss_fn(predictions, targets)
    return loss

現在,讓我們使用 grad 轉換來建立一個新的函數,該函數計算關於 compute_loss 的第一個參數(即 params)的梯度。

ft_compute_grad = grad(compute_loss)

ft_compute_grad 函數計算單個(樣本、目標)對的梯度。我們可以使用 vmap 使其計算整個樣本和目標批次的梯度。請注意,in_dims=(None, None, 0, 0),因為我們希望將 ft_compute_grad 映射到資料和目標的第 0 維度上,並對每個使用相同的 params 和緩衝區。

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

最後,讓我們使用我們的轉換函數來計算逐樣本梯度

ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

我們可以仔細檢查使用 gradvmap 的結果是否與手動處理每個結果的結果相符

for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

快速提醒:vmap 能夠轉換的函式類型有一些限制。最適合轉換的函式是純函式:輸出僅由輸入決定的函式,並且沒有副作用(例如,變異)。vmap 無法處理任意 Python 資料結構的變異,但它可以處理許多就地 PyTorch 運算。

效能比較

想知道 vmap 的效能如何嗎?

目前,在較新的 GPU 上(例如 A100 (Ampere))可以獲得最佳結果,在這個範例中,我們已經看到高達 25 倍的加速,但這裡有一些在我們的建置機器上的結果

def get_perf(first, first_descriptor, second, second_descriptor):
    """takes torch.benchmark objects and compares delta of second vs first."""
    second_res = second.times[0]
    first_res = first.times[0]

    gain = (first_res-second_res)/first_res
    if gain < 0: gain *=-1
    final_gain = gain*100

    print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")

from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)

print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')

get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")
Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7ff794a9eb60>
compute_sample_grads(data, targets)
  98.50 ms
  1 measurement, 100 runs , 1 thread
Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7ff6fd5efee0>
ft_compute_sample_grad(params, buffers, data, targets)
  8.63 ms
  1 measurement, 100 runs , 1 thread
Performance delta: 1040.9007 percent improvement with vmap

還有其他優化的解決方案(例如在 https://github.com/pytorch/opacus 中)可以在 PyTorch 中計算每個樣本的梯度,它們的效能也優於簡單的方法。但 vmapgrad 的組合能提供良好的加速,這很酷。

通常,使用 vmap 進行向量化應該比在 for 迴圈中運行函式更快,並且可以與手動批次處理相媲美。但是,也存在一些例外情況,例如,如果我們沒有為特定操作實作 vmap 規則,或者如果底層核心沒有針對較舊的硬體(GPU)進行優化。如果您發現任何這些情況,請在 GitHub 上提出 issue 告訴我們。

指令碼總執行時間: ( 0 分鐘 11.834 秒)

由 Sphinx-Gallery 產生圖庫

文件

取得 PyTorch 完整的開發者文件

檢視文件

教程

取得為初學者和高級開發者提供的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源