快捷方式

torch.vmap

創建於:2020 年 10 月 26 日 | 最後更新:2021 年 9 月 01 日 | 最後驗證:未驗證

本教學課程介紹 torch.vmap,PyTorch 運算的自動向量化器。 torch.vmap 是一個原型功能,無法處理許多用例;但是,我們希望收集用例,以作為設計的參考。 如果您正在考慮使用 torch.vmap,或認為它對於某些事情來說非常酷,請透過 https://github.com/pytorch/pytorch/issues/42368 與我們聯繫。

那麼,什麼是 vmap?

vmap 是一個高階函數。 它接受一個函數 func 並返回一個新函數,該函數將 func 映射到輸入的某些維度上。 它深受 JAX 的 vmap 的啟發。

從語義上講,vmap 將 “map” 推送到 func 呼叫的 PyTorch 運算中,從而有效地向量化這些運算。

import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap

vmap 的第一個用例是簡化程式碼中批次維度的處理。 可以編寫一個在範例上運行的函數 func,然後使用 vmap(func) 將其提升為可以處理多批範例的函數。 但是,func 受到許多限制

  • 它必須是函數式的(不能在其中改變 Python 數據結構),但 PyTorch 的原地運算除外。

  • 範例批次必須作為張量提供。 這意味著 vmap 無法開箱即用地處理可變長度的序列。

使用 vmap 的一個範例是計算批次的點積。 PyTorch 沒有提供批次的 torch.dot API;與其徒勞地翻閱文檔,不如使用 vmap 來建構一個新函數

torch.dot                            # [D], [D] -> []
batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)

vmap 有助於隱藏批次維度,從而簡化模型撰寫體驗。

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).
def model(feature_vec):
    # Very simple linear model with activation
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)

vmap 也可以幫助向量化以前難以或不可能批次處理的計算。 這將我們帶到第二個用例:批次梯度計算。

PyTorch 自動微分引擎計算 vjps(向量-雅可比積)。 使用 vmap,我們可以計算(批次向量)- 雅可比積。

這方面的一個範例是計算完整的雅可比矩陣(這也可以應用於計算完整的 Hessian 矩陣)。 計算某個函數 f: R^N -> R^N 的完整雅可比矩陣通常需要對 autograd.grad 進行 N 次呼叫,每次呼叫對應一個雅可比行。

# Setup
N = 5
def f(x):
    return x ** 2

x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)

# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
                 for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)

# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
    return torch.autograd.grad(y, x, v)[0]

jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)

vmap 的第三個主要用例是計算每個樣本的梯度。 這是 vmap 原型目前無法高效處理的事情。 我們不確定計算每個樣本的梯度應該使用什麼 API,但如果您有任何想法,請在 https://github.com/pytorch/pytorch/issues/7786 中發表評論。

def model(sample, weight):
    # do something...
    return torch.dot(sample, weight)

def grad_sample(sample):
    return torch.autograd.functional.vjp(lambda weight: model(sample), weight)[1]

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)

腳本總運行時間: ( 0 分鐘 0.000 秒)

由 Sphinx-Gallery 生成的圖庫

文件

存取 PyTorch 的全面開發者文檔

查看文檔

教學課程

獲取針對初學者和高級開發者的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源