捷徑

torch.func.vmap

torch.func.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)[原始碼]

vmap 是一個向量化映射 (vectorizing map);vmap(func) 返回一個新的函數,該函數將 func 映射到輸入的某些維度上。從語義上講,vmap 將映射推送到 func 呼叫的 PyTorch 運算中,有效地向量化了這些運算。

vmap 對於處理批次 (batch) 維度非常有用:可以編寫一個在範例上運行的函數 func,然後使用 vmap(func) 將其提升為可以接收多個範例批次的函數。當與 autograd 組合使用時,vmap 也可以用於計算批次梯度。

注意

為方便起見,torch.vmap()torch.func.vmap() 的別名。 請隨意使用其中一個。

參數
  • func (function) – 一個接受一個或多個參數的 Python 函數。 必須返回一個或多個 Tensor。

  • in_dims (int巢狀結構 (nested structure)) – 指定應映射輸入的哪個維度。in_dims 應具有與輸入類似的結構。 如果特定輸入的 in_dim 為 None,則表示沒有映射維度。 預設值:0。

  • out_dims (intTuple[int]) – 指定映射的維度應出現在輸出的哪個位置。 如果 out_dims 是一個 Tuple,則它應該每個輸出有一個元素。 預設值:0。

  • randomness (str) – 指定此 vmap 中的隨機性在批次之間應相同還是不同。 如果為 'different',則每個批次的隨機性將不同。 如果為 'same',則批次之間的隨機性將相同。 如果為 'error',則對隨機函數的任何呼叫都會出錯。 預設值:'error'。警告:此標誌僅適用於隨機 PyTorch 運算,不適用於 Python 的 random 模組或 numpy 隨機性。

  • chunk_size (Noneint) – 如果為 None (預設值),則對輸入應用單個 vmap。 如果不為 None,則一次計算 vmap chunk_size 個樣本。 請注意,chunk_size=1 等同於使用 for 迴圈計算 vmap。 如果您在計算 vmap 時遇到記憶體問題,請嘗試非 None 的 chunk_size。

返回

返回一個新的「批次化 (batched)」函數。 它採用與 func 相同的輸入,除了每個輸入在 in_dims 指定的索引處都有一個額外的維度。 它採用與 func 相同的輸出返回,除了每個輸出在 out_dims 指定的索引處都有一個額外的維度。

返回類型

可呼叫 (Callable)

使用 vmap() 的一個範例是計算批次化的點積。 PyTorch 沒有提供批次化的 torch.dot API; 與其徒勞無功地翻找文件,不如使用 vmap() 來建構一個新函數。

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)

vmap() 有助於隱藏批次維度,從而簡化模型編寫體驗。

>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)

vmap() 也有助於向量化以前難以或無法批次處理的計算。 一個範例是高階梯度計算。 PyTorch autograd 引擎計算 vjps (向量-雅可比行列式乘積)。 計算某個函數 f: R^N -> R^N 的完整雅可比行列式通常需要 N 次呼叫 autograd.grad,每次呼叫對應於雅可比行列式的一行。 使用 vmap(),我們可以向量化整個計算,並透過單次呼叫 autograd.grad 來計算雅可比行列式。

>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>>                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)

vmap() 也可以巢狀使用,產生具有多個批次維度的輸出

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y) # tensor of size [2, 3]

如果輸入不是沿著第一個維度批次化,則 in_dims 指定每個輸入批次化所沿著的維度,如下所示

>>> torch.dot                            # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)   # output is [5] instead of [2] if batched along the 0th dimension

如果有多個輸入,且每個輸入沿著不同的維度批次化,則 in_dims 必須是一個元組,其中包含每個輸入的批次維度,如下所示

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None

如果輸入是一個 Python 結構,則 in_dims 必須是一個元組,其中包含一個與輸入形狀相符的結構

>>> f = lambda dict: torch.dot(dict['x'], dict['y'])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {'x': x, 'y': y}
>>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> batched_dot(input)

預設情況下,輸出沿著第一個維度批次化。 但是,可以使用 out_dims 沿著任何維度批次化

>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x) # [5, 2]

對於任何使用 kwargs 的函數,返回的函數將不會批次化 kwargs,但會接受 kwargs

>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
>>>   return x * scale
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]

注意

vmap 不提供通用的自動批次處理 (autobatching) 或處理開箱即用的可變長度序列。

文件

存取 PyTorch 完整的開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得您問題的解答

檢視資源