torch.autograd.Function.vmap¶
- static Function.vmap(info, in_dims, *args)[原始碼][原始碼]¶
定義此 autograd.Function 在
torch.vmap()
下的行為。對於一個
torch.autograd.Function()
為了支援torch.vmap()
,你必須覆寫這個靜態方法,或者設定generate_vmap_rule
為True
(你不能同時做這兩件事)。如果你選擇覆寫這個靜態方法:它必須接受
一個
info
物件作為第一個參數。info.batch_size
指定正在進行 vmap 的維度的大小,而info.randomness
是傳遞給torch.vmap()
的隨機性選項。一個
in_dims
元組作為第二個參數。對於args
中的每個 arg,in_dims
都有一個對應的Optional[int]
。如果 arg 不是 Tensor,或者 arg 沒有被 vmap,則它是None
,否則,它是一個整數,指定 Tensor 的哪個維度正在被 vmap。*args
,它與傳遞給forward()
的 args 相同。
vmap 靜態方法的返回值是一個
(output, out_dims)
的元組。與in_dims
類似,out_dims
應該與output
具有相同的結構,並且每個輸出包含一個out_dim
,用於指定輸出是否具有 vmap 維度以及它在哪個索引中。請參閱 使用 autograd.Function 擴充 torch.func 以獲取更多詳細訊息。