捷徑

torch.autograd.Function.vmap

static Function.vmap(info, in_dims, *args)[原始碼][原始碼]

定義此 autograd.Function 在 torch.vmap() 下的行為。

對於一個 torch.autograd.Function() 為了支援 torch.vmap(),你必須覆寫這個靜態方法,或者設定 generate_vmap_ruleTrue (你不能同時做這兩件事)。

如果你選擇覆寫這個靜態方法:它必須接受

  • 一個 info 物件作為第一個參數。info.batch_size 指定正在進行 vmap 的維度的大小,而 info.randomness 是傳遞給 torch.vmap() 的隨機性選項。

  • 一個 in_dims 元組作為第二個參數。對於 args 中的每個 arg,in_dims 都有一個對應的 Optional[int]。如果 arg 不是 Tensor,或者 arg 沒有被 vmap,則它是 None,否則,它是一個整數,指定 Tensor 的哪個維度正在被 vmap。

  • *args,它與傳遞給 forward() 的 args 相同。

vmap 靜態方法的返回值是一個 (output, out_dims) 的元組。與 in_dims 類似,out_dims 應該與 output 具有相同的結構,並且每個輸出包含一個 out_dim,用於指定輸出是否具有 vmap 維度以及它在哪個索引中。

請參閱 使用 autograd.Function 擴充 torch.func 以獲取更多詳細訊息。

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得適用於初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源