捷徑

每個樣本梯度

Open In Colab

它是什麼?

每個樣本梯度的計算是指計算一批資料中每個樣本的梯度。它是差分隱私、元學習和最佳化研究中的一個有用量。

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

torch.manual_seed(0);
# Here's a simple CNN and loss function:

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        output = x
        return output

def loss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)

讓我們產生一批虛擬資料,並假設我們正在處理 MNIST 資料集。

虛擬圖像的大小為 28x28,我們使用大小為 64 的迷你批次。

device = 'cuda'

num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)

targets = torch.randint(10, (64,), device=device)

在常規模型訓練中,會將迷你批次輸入模型,然後呼叫 .backward() 來計算梯度。這將產生整個迷你批次的「平均」梯度。

model = SimpleCNN().to(device=device)
predictions = model(data) # move the entire mini-batch through the model

loss = loss_fn(predictions, targets)
loss.backward() # back propogate the 'average' gradient of this mini-batch

與上述方法相反,每個樣本梯度的計算是相當於

  • 針對每個單獨的資料樣本,執行一次正向和反向傳遞,以獲得單獨的(每個樣本)梯度。

def compute_grad(sample, target):
    
    sample = sample.unsqueeze(0)  # prepend batch dimension for processing
    target = target.unsqueeze(0)

    prediction = model(sample)
    loss = loss_fn(prediction, target)

    return torch.autograd.grad(loss, list(model.parameters()))


def compute_sample_grads(data, targets):
    """ manually process each sample with per sample gradient """
    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
    sample_grads = zip(*sample_grads)
    sample_grads = [torch.stack(shards) for shards in sample_grads]
    return sample_grads

per_sample_grads = compute_sample_grads(data, targets)

sample_grads[0] 是 model.conv1.weight 的每個樣本梯度。 model.conv1.weight.shape[32, 1, 3, 3];請注意,批次中每個樣本都有一個梯度,總共有 64 個。

print(per_sample_grads[0].shape)
torch.Size([64, 32, 1, 3, 3])

使用 functorch 計算每個樣本梯度的*有效方法*

我們可以使用函數變換來有效地計算每個樣本的梯度。

首先,讓我們使用 functorch.make_functional_with_buffers 建立 model 的無狀態函數版本。

這將把狀態(參數)與模型分離,並將模型轉換為純函數。

from functorch import make_functional_with_buffers, vmap, grad

fmodel, params, buffers = make_functional_with_buffers(model)

讓我們回顧一下變化 - 首先,模型已經變成了無狀態的 FunctionalModuleWithBuffers。

fmodel
FunctionalModuleWithBuffers(
  (stateless_model): SimpleCNN(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (fc1): Linear(in_features=9216, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=10, bias=True)
  )
)

而模型參數現在獨立於模型存在,存儲為一個元組。

for x in params:
  print(f"{x.shape}")

print(f"\n{type(params)}")
torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([64, 32, 3, 3])
torch.Size([64])
torch.Size([128, 9216])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])

<class 'tuple'>

接下來,讓我們定義一個函數來計算給定單個輸入而不是一批輸入的模型損失。重要的是,這個函數要接受參數、輸入和目標,因為我們將對它們進行變換。

注意 - 因為模型最初是寫成處理批次的,所以我們將使用 torch.unsqueeze 來添加批次維度。

def compute_loss_stateless_model (params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch) 
    loss = loss_fn(predictions, targets)
    return loss

現在,讓我們使用 functorch 的 grad 來建立一個新的函數,用於計算相對於 compute_loss 的第一個參數(即參數)的梯度。

ft_compute_grad = grad(compute_loss_stateless_model)

ft_compute_grad 函數計算單個(樣本、目標)對的梯度。我們可以使用 vmap 來讓它計算整個批次樣本和目標的梯度。注意,in_dims=(None, None, 0, 0) 是因為我們希望將 ft_compute_grad 映射到資料和目標的第 0 維度,並為每個資料和目標使用相同的參數和緩衝區。

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

最後,讓我們使用變換後的函數來計算每個樣本的梯度。

ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

簡要說明:vmap 可以變換的函數類型是有限制的。最適合變換的函數是純函數:輸出僅由輸入決定的函數,並且沒有副作用(例如突變)。vmap 無法處理任意 Python 資料結構的突變,但它可以處理許多就地 PyTorch 操作。

效能比較

好奇 vmap 的效能如何比較嗎?

目前,在較新的 GPU(例如 A100(Ampere))上獲得了最佳結果,在這個例子中,我們看到了高達 25 倍的速度提升,但這裡有一些在 Colab 中完成的結果。

def get_perf(first, first_descriptor, second, second_descriptor):
  """  takes torch.benchmark objects and compares delta of second vs first. """
  second_res = second.times[0]
  first_res = first.times[0]

  gain = (first_res-second_res)/first_res
  if gain < 0: gain *=-1 
  final_gain = gain*100

  print(f" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")
from torch.utils.benchmark import Timer

without_vmap = Timer( stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)

print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')
Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f71ac3f1850>
compute_sample_grads(data, targets)
  79.86 ms
  1 measurement, 100 runs , 1 thread
Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f7143e26f10>
ft_compute_sample_grad(params, buffers, data, targets)
  12.93 ms
  1 measurement, 100 runs , 1 thread
get_perf(with_vmap_timing, "vmap", no_vmap_timing,"no vmap" )
 Performance delta: 517.5791 percent improvement with vmap 

在 PyTorch 中,還有其他優化的解決方案(例如在 https://github.com/pytorch/opacus 中)來計算每個樣本的梯度,這些解決方案的效能也優於原始方法。但是,組合 vmapgrad 能夠提供不錯的速度提升,這一點很酷。

一般來說,使用 vmap 進行向量化應該比在 for 迴圈中運行函數更快,並且與手動批次處理相比具有競爭力。但也有一些例外情況,例如,如果我們沒有為特定操作實現 vmap 規則,或者如果底層內核沒有針對舊硬體(GPU)進行優化。如果您遇到任何這些情況,請在我們的 GitHub 上提交 issue 告訴我們!

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源