注意
點擊這裡下載完整的範例程式碼
torch.vmap¶
創建於:2020 年 10 月 26 日 | 最後更新:2021 年 9 月 01 日 | 最後驗證:未驗證
本教學課程介紹 torch.vmap,PyTorch 運算的自動向量化器。 torch.vmap 是一個原型功能,無法處理許多用例;但是,我們希望收集用例,以作為設計的參考。 如果您正在考慮使用 torch.vmap,或認為它對於某些事情來說非常酷,請透過 https://github.com/pytorch/pytorch/issues/42368 與我們聯繫。
那麼,什麼是 vmap?¶
vmap 是一個高階函數。 它接受一個函數 func 並返回一個新函數,該函數將 func 映射到輸入的某些維度上。 它深受 JAX 的 vmap 的啟發。
從語義上講,vmap 將 “map” 推送到 func 呼叫的 PyTorch 運算中,從而有效地向量化這些運算。
import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap
vmap 的第一個用例是簡化程式碼中批次維度的處理。 可以編寫一個在範例上運行的函數 func,然後使用 vmap(func) 將其提升為可以處理多批範例的函數。 但是,func 受到許多限制
它必須是函數式的(不能在其中改變 Python 數據結構),但 PyTorch 的原地運算除外。
範例批次必須作為張量提供。 這意味著 vmap 無法開箱即用地處理可變長度的序列。
使用 vmap 的一個範例是計算批次的點積。 PyTorch 沒有提供批次的 torch.dot API;與其徒勞地翻閱文檔,不如使用 vmap 來建構一個新函數
torch.dot # [D], [D] -> []
batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)
vmap 有助於隱藏批次維度,從而簡化模型撰寫體驗。
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)
# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).
def model(feature_vec):
# Very simple linear model with activation
return feature_vec.dot(weights).relu()
examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)
vmap 也可以幫助向量化以前難以或不可能批次處理的計算。 這將我們帶到第二個用例:批次梯度計算。
PyTorch 自動微分引擎計算 vjps(向量-雅可比積)。 使用 vmap,我們可以計算(批次向量)- 雅可比積。
這方面的一個範例是計算完整的雅可比矩陣(這也可以應用於計算完整的 Hessian 矩陣)。 計算某個函數 f: R^N -> R^N 的完整雅可比矩陣通常需要對 autograd.grad 進行 N 次呼叫,每次呼叫對應一個雅可比行。
# Setup
N = 5
def f(x):
return x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)
# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)
# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
return torch.autograd.grad(y, x, v)[0]
jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)
vmap 的第三個主要用例是計算每個樣本的梯度。 這是 vmap 原型目前無法高效處理的事情。 我們不確定計算每個樣本的梯度應該使用什麼 API,但如果您有任何想法,請在 https://github.com/pytorch/pytorch/issues/7786 中發表評論。
def model(sample, weight):
# do something...
return torch.dot(sample, weight)
def grad_sample(sample):
return torch.autograd.functional.vjp(lambda weight: model(sample), weight)[1]
# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.
# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
腳本總運行時間: ( 0 分鐘 0.000 秒)