functorch.vmap¶
-
functorch.
vmap
(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)[來源]¶ vmap 是一個向量化映射函數;
vmap(func)
會回傳一個新的函數,該函數會將func
映射到輸入的某個維度上。語義上,vmap 將映射推送到func
所呼叫的 PyTorch 操作中,有效地將這些操作向量化。vmap 對於處理批次維度非常有用:您可以撰寫一個作用於單個範例的函數
func
,然後使用vmap(func)
將其提升為可以處理一批範例的函數。vmap 與 autograd 組合使用時,還可以用於計算批次梯度。注意事項
為了方便起見,
torch.vmap()
與torch.func.vmap()
為別名。您可以使用任何一個。- 參數
func (函數) – 接收一個或多個引數的 Python 函數。必須回傳一個或多個張量。
in_dims (整數 或 巢狀結構) – 指定要映射的輸入維度。
in_dims
的結構應該與輸入的結構相同。如果特定輸入的in_dim
為 None,則表示沒有映射維度。預設值:0。out_dims (整數 或 元組[整數]) – 指定映射維度應該出現在輸出的哪個位置。如果
out_dims
是一個元組,則它應該為每個輸出包含一個元素。預設值:0。randomness (字串) – 指定此 vmap 中的隨機性在不同批次之間應該相同還是不同。如果為「different」,則每個批次的隨機性將會不同。如果為「same」,則隨機性在不同批次之間將會相同。如果為「error」,則任何對隨機函數的呼叫都會產生錯誤。預設值:「error」。警告:此旗標僅適用於 PyTorch 的隨機操作,不適用於 Python 的 random 模組或 numpy 的隨機性。
chunk_size (None 或 整數) – 如果為 None(預設值),則對輸入套用單個 vmap。如果不是 None,則一次計算
chunk_size
個樣本的 vmap。請注意,chunk_size=1
等同於使用 for 迴圈計算 vmap。如果您在計算 vmap 時遇到記憶體問題,請嘗試使用非 None 的 chunk_size。
- 回傳值
返回一個新的「批次化」函數。它接受與
func
相同的輸入,但每個輸入在in_dims
指定的索引處會增加一個維度。它返回與func
相同的輸出,但每個輸出在out_dims
指定的索引處會增加一個維度。
使用
vmap()
的一個例子是計算批次化的點積。PyTorch 沒有提供批次化的torch.dot
API;与其在文件中徒勞無功地搜尋,不如使用vmap()
來建構一個新的函數。>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot(x, y)
vmap()
可以幫助隱藏批次維度,從而簡化模型的編寫體驗。>>> batch_size, feature_size = 3, 5 >>> weights = torch.randn(feature_size, requires_grad=True) >>> >>> def model(feature_vec): >>> # Very simple linear model with activation >>> return feature_vec.dot(weights).relu() >>> >>> examples = torch.randn(batch_size, feature_size) >>> result = torch.vmap(model)(examples)
vmap()
也可以幫助向量化以前難以或不可能批次化的計算。一個例子是高階梯度計算。PyTorch 的自動微分引擎計算 vjps(向量-雅可比乘積)。計算函數 f: R^N -> R^N 的完整雅可比矩陣通常需要 N 次呼叫autograd.grad
,每個雅可比行一次。使用vmap()
,我們可以向量化整個計算,在一次呼叫autograd.grad
中計算雅可比矩陣。>>> # Setup >>> N = 5 >>> f = lambda x: x ** 2 >>> x = torch.randn(N, requires_grad=True) >>> y = f(x) >>> I_N = torch.eye(N) >>> >>> # Sequential approach >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] >>> for v in I_N.unbind()] >>> jacobian = torch.stack(jacobian_rows) >>> >>> # vectorized gradient computation >>> def get_vjp(v): >>> return torch.autograd.grad(y, x, v) >>> jacobian = torch.vmap(get_vjp)(I_N)
vmap()
也可以巢狀使用,產生具有多個批次維度的輸出。>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0] >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) >>> batched_dot(x, y) # tensor of size [2, 3]
如果輸入不是沿著第一個維度批次化的,
in_dims
會指定每個輸入沿著哪個維度批次化。>>> torch.dot # [N], [N] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension
如果有多個輸入,每個輸入都沿著不同的維度批次化,則
in_dims
必須是一個元組,其中包含每個輸入的批次維度。>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(5) >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None
如果輸入是一個 Python 結構,
in_dims
必須是一個包含與輸入形狀匹配的結構的元組。>>> f = lambda dict: torch.dot(dict['x'], dict['y']) >>> x, y = torch.randn(2, 5), torch.randn(5) >>> input = {'x': x, 'y': y} >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},)) >>> batched_dot(input)
預設情況下,輸出沿著第一個維度批次化。但是,可以使用
out_dims
將其沿著任何維度批次化。>>> f = lambda x: x ** 2 >>> x = torch.randn(2, 5) >>> batched_pow = torch.vmap(f, out_dims=1) >>> batched_pow(x) # [5, 2]
對於任何使用 kwargs 的函數,返回的函數不會批次化 kwargs,但會接受 kwargs。
>>> x = torch.randn([2, 5]) >>> def fn(x, scale=4.): >>> return x * scale >>> >>> batched_pow = torch.vmap(fn) >>> assert torch.allclose(batched_pow(x), x * 4) >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
注意事項
vmap 不提供通用的自動批次化,也不處理可變長度序列。
警告
我們已將 functorch 整合到 PyTorch 中。作為整合的最後一步,functorch.vmap 已在 PyTorch 2.0 中棄用,並將在 PyTorch >= 2.3 的未來版本中刪除。請改用 torch.vmap;有關更多詳細資訊,請參閱 PyTorch 2.0 版本說明和/或 torch.func 遷移指南 https://pytorch.dev.org.tw/docs/master/func.migrating.html