使用 autograd.Function 擴展 torch.func¶
所以您想要搭配 torch.autograd.Function
使用 torch.func
轉換,例如 torch.vmap()
、torch.func.grad()
等等。
主要有兩種使用案例
您希望呼叫不包含 PyTorch 運算的程式碼,並讓其與函數轉換搭配使用。也就是說,
torch.autograd.Function
的 forward/backward/etc 會呼叫來自其他系統(如 C++、CUDA、numpy)的函數。您希望指定自訂梯度規則,例如 JAX 的 custom_vjp/custom_jvp
PyTorch 將這兩個概念結合到 torch.autograd.Function
中。
基本用法¶
本指南假設您已熟悉擴展 torch.autograd,其中說明了如何使用 torch.autograd.Function
。
torch.autograd.Function
可以具有接受 ctx 物件的 forward()
,或者它可以具有獨立的 forward()
(不接受 ctx
) 和修改 ctx
物件的 setup_context()
靜態方法。
函數轉換僅支援後者
forward()
是執行運算的程式碼,它不應接受ctx
物件。setup_context(ctx, inputs, output)
是您可以在ctx
上呼叫方法的程式碼。 您應該在此處儲存用於反向傳播的張量 (透過呼叫ctx.save_for_backward(*tensors)
),或儲存非張量 (透過將它們分配給ctx
物件)。
由於 setup_context()
僅接受 inputs
和 output
,因此可以儲存的唯一量是輸入或輸出中的物件(例如張量)或從它們衍生的量(例如 Tensor.shape
)。 如果您希望從 Function.forward()
儲存非輸入的中間激活值以進行反向傳播,則需要將其作為 forward()
的輸出返回,以便將其傳遞到 setup_context()
。
取決於轉換,
為了支援反向模式 AD (
torch.func.grad()
,torch.func.vjp()
),torch.autograd.Function
需要backward()
靜態方法。為了支援
torch.vmap()
,torch.autograd.Function
需要vmap()
靜態方法。為了支援
torch.func.jvp()
,torch.autograd.Function
需要jvp()
靜態方法。為了支援轉換的組合 (例如
torch.func.jacrev()
,torch.func.jacfwd()
,torch.func.hessian()
) – 您可能需要上述多個。
為了使 torch.autograd.Function
可以與函數轉換任意組合,我們建議除了 forward()
和 setup_context()
之外的所有其他靜態方法都必須是可轉換的:也就是說,它們必須僅包含 PyTorch 運算符或呼叫其他 torch.autograd.Function
(可以呼叫到 C++/CUDA/etc)。
讓我們來看一些常見用例的範例。
範例 1:autograd.Function 呼叫到另一個系統¶
常見的情況是具有 forward() 和 backward() 的 torch.autograd.Function
呼叫到另一個系統 (例如 C++、CUDA、numpy、triton)。
import torch
import numpy as np
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
# Note that forward does not take ctx
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
# Any intermediates to be saved in backward must be returned as
# outputs.
return (
# The desired output
torch.tensor(result, device=device),
# intermediate to save for backward
torch.tensor(ind, device=device),
# intermediate to save for backward
torch.tensor(ind_inv, device=device),
)
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
# Note that output is whatever you returned from forward.
# If you returned multiple values, then output is a Tuple of multiple values.
# If you returned a single Tensor, then output is a Tensor.
# If you returned a Tuple with a single Tensor, then output is a
# Tuple with a single Tensor.
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
# Tensors must be saved via ctx.save_for_backward. Please do not
# assign them directly onto the ctx object.
ctx.save_for_backward(ind, ind_inv)
# Non-tensors may be saved by assigning them as attributes on the ctx object.
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
# For the autograd.Function to be arbitrarily composable with function
# transforms, all staticmethod other than forward and setup_context
# must be implemented in a "transformable" way; that is, they must
# only consist of PyTorch operations or autograd.Function.
#
# For example, this allows us to do double backwards and/or compute
# second order gradients.
#
# We've written the backward pass of NumpySort in terms of another
# autograd.Function, NumpyTake.
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
現在,為了更容易使用 NumpySort
(以隱藏我們作為輸出返回的中間值,以及允許預設參數和關鍵字參數),我們建立一個呼叫它的新函數
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
這是一個健全性檢查
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
範例 2:autograd.Function 指定自定義梯度規則¶
另一個常見的情況是使用 PyTorch 運算實現的 torch.autograd.Function
。 PyTorch 能夠自動計算 PyTorch 運算的梯度,但我們可能希望自定義梯度的計算方式。 我們可能需要與 PyTorch 提供的自定義 backward 不同的原因有
提高數值穩定性
變更反向傳播的效能特性
變更邊緣案例的處理方式 (例如 nans, inf)
修改梯度 (例如梯度裁剪)
以下是一個 torch.autograd.Function
的範例,用於函數 y = x ** 3
,我們變更了效能特性 (一些通常在反向傳播期間發生的計算,即計算 dx,在正向傳播期間發生)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# In regular PyTorch, if we had just run y = x ** 3, then the backward
# pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
# that computation here in the forward pass instead.
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`.
result = grad_output * dx + grad_dx * 6 * x
return result
現在,為了更容易使用 NumpySort
(並隱藏我們作為輸出返回的中間值),我們建立一個新的函數來調用它
def my_cube(x):
result, _ = MyCube.apply(x)
return result
這是一個用於計算二階梯度的健全性檢查
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制與注意事項¶
警告
請仔細閱讀 torch.autograd.Function
與 torch.func 轉換的這些限制。 我們無法捕捉到許多這些情況並優雅地報錯,因此它們會導致未定義的行為。
請不要將正在轉換、具有 requires_grad=True 或雙重張量的張量捕獲到 torch.autograd.Function
的方法中。 要完全安全的方式是確保在 torch.autograd.Function
的任何方法內使用的唯一張量必須直接作為輸入傳遞 (或透過 ctx 物件),而不是來自 torch.autograd.Function
外部。
torch.autograd.Function
不處理 pytree 中的張量(可能包含或不包含張量的任意巢狀 Python 資料結構)。 為了讓 autograd 追蹤這些張量,它們必須直接作為參數傳遞給 torch.autograd.Function
。 這與 jax.{custom_vjp, custom_jvp} 不同,後者確實接受 pytree。
請僅使用 save_for_backward()
或 save_for_forward()
來儲存張量。 請勿將張量或張量集合直接分配到 ctx 物件上 - 這些張量將不會被追蹤
torch.vmap()
支援¶
若要將 torch.autograd.Function
與 torch.vmap()
一起使用,您必須
提供一個
vmap()
staticmethod,告訴我們torch.autograd.Function
在torch.vmap()
下的行為要求我們透過設定
generate_vmap_rule=True
來自動產生它。
自動產生 vmap 規則¶
如果您的 torch.autograd.Function
符合以下額外限制,那麼我們可以為它產生 vmap 規則。 如果它不符合限制,或者您想要在 vmap 下的自訂行為,請手動定義 vmap staticmethod (請參閱下一節)。
警告
我們不容易檢查以下限制並優雅地報錯。 違反限制可能會導致未定義的行為。
torch.autograd.Function
的forward()
,backward()
(如果存在) 和jvp()
(如果存在) staticmethod 必須可透過torch.vmap()
轉換。 也就是說,它們必須僅包含 PyTorch 操作 (而不是例如 NumPy 或自訂 CUDA 核心)。
範例
class MyCube(torch.autograd.Function):
# Set generate_vmap_rule to True to ask PyTorch to automatically generate
# a vmap rule.
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定義 vmap staticmethod¶
如果您的 torch.autograd.Function
呼叫另一個系統 (如 NumPy, C++, CUDA, triton),那麼為了使其與 torch.vmap()
或使用它的轉換一起工作,您需要手動定義一個 vmap()
staticmethod。
根據您想要使用的轉換和您的使用案例,您可能不需要將 vmap()
staticmethod 新增到您的所有 torch.autograd.Function
例如,
torch.func.jacrev()
在反向傳播上執行vmap()
。 因此,如果您只對使用torch.func.jacrev()
感興趣,則只需要backward()
staticmethod 是可 vmap 的。
我們建議確保您的所有 torch.autograd.Function
都支援 torch.vmap()
,尤其是在您編寫第三方函式庫,且希望您的 torch.autograd.Function
能與 torch.func()
轉換的所有組合搭配使用時。
從概念上講,vmap 靜態方法負責定義 forward()
在 torch.vmap()
下的行為方式。也就是說,它定義了如何轉換 forward()
以在具有額外維度(正在進行 vmap 的維度)的輸入上執行。這與 torch.vmap()
在 PyTorch 運算上的實作方式類似:對於每個運算,我們定義一個 vmap 規則(有時也稱為「批次規則」)。
以下是如何定義 vmap()
靜態方法
簽名為
vmap(info, in_dims: Tuple[Optional[int]], *args)
,其中*args
與forward()
的 args 相同。vmap 靜態方法負責定義
forward()
在torch.vmap()
下的行為方式。也就是說,給定具有額外維度(由in_dims
指定)的輸入,我們該如何計算forward()
的批次版本?對於
args
中的每個 arg,in_dims
都有一個對應的Optional[int]
。如果 arg 不是 Tensor,或者 arg 沒有被 vmap,則為None
,否則,它是一個整數,指定 Tensor 的哪個維度正在被 vmap。info
是一組可能有用的額外元數據:info.batch_size
指定正在進行 vmap 的維度的大小,而info.randomness
是傳遞給torch.vmap()
的randomness
選項。vmap 靜態方法的傳回值是一個
(output, out_dims)
元組。與in_dims
類似,out_dims
應與output
具有相同的結構,並且每個輸出包含一個out_dim
,用於指定輸出是否具有 vmap 維度以及它所在的索引。
範例
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# The signature of the vmap staticmethod is:
# vmap(info, in_dims: Tuple[Optional[int]], *args)
# where *args is the same as the arguments to `forward`.
@staticmethod
def vmap(info, in_dims, x, dim):
# For every input (x and dim), in_dims stores an Optional[int]
# that is:
# - None if the input is not being vmapped over or if the input
# is not a Tensor
# - an integer if the input is being vmapped over that represents
# the index of the dimension being vmapped over.
x_bdim, _ = in_dims
# A "vmap rule" is the logic of how to perform the operation given
# inputs with one additional dimension. In NumpySort, x has an
# additional dimension (x_bdim). The vmap rule is simply
# to call NumpySort again but pass it a different `dim`.
x = x.movedim(x_bdim, 0)
# Handle negative dims correctly
dim = dim if dim >= 0 else dim + x.dim() - 1
result = NumpySort.apply(x, dim + 1)
# The vmap rule must return a tuple of two things
# 1. the output. Should be the same amount of things
# as returned by the forward().
# 2. one Optional[int] for each output specifying if each output
# is being vmapped over, and if so, the index of the
# dimension being vmapped over.
#
# NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
# dimension being vmapped over to the front of `x`, that appears at
# dimension 0 of all outputs.
# The return is (output, out_dims) -- output is a tuple of 3 Tensors
# and out_dims is a Tuple of 3 Optional[int]
return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# The strategy is: expand {x, ind, ind_inv} to all have the dimension
# being vmapped over.
# Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
# Handle negative dims by wrapping them to be positive
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def maybe_expand_bdim_at_front(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
# If the Tensor doesn't have the dimension being vmapped over,
# expand it out. Otherwise, move it to the front of the Tensor
x = maybe_expand_bdim_at_front(x, x_bdim)
ind = maybe_expand_bdim_at_front(ind, ind_bdim)
ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
# The return is a tuple (output, out_dims). Since output is a Tensor,
# then out_dims is an Optional[int] (instead of being a Tuple).
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap 靜態方法應旨在保留整個 Function
的語義。也就是說,(偽代碼)grad(vmap(MyFunc))
應該可以用 grad(map(MyFunc))
替換。
如果您的 autograd.Function 在反向傳播中具有任何自訂行為,請記住這一點。
注意
對於 PyTorch 可以通過 generate_vmap_rule=True
生成 vmap 規則的 Function
,編寫自訂 vmap 靜態方法是一個合法的用例。如果您生成的 vmap 規則沒有您要尋找的語義,您可能希望這樣做。
torch.func.jvp()
支援¶
為了支援正向模式 AD,torch.autograd.Function
必須具有 jvp()
靜態方法。請參閱 正向模式 AD 以取得詳細資訊。