捷徑

functorch.combine_state_for_ensemble

functorch.combine_state_for_ensemble(models)func, params, buffers[source]

準備一份將 vmap()torch.nn.Modules 進行整合使用的清單。

給定一個相同類別的 M nn.Modules 清單,將它們的所有參數與緩衝區堆疊,以產生 paramsbuffers。結果中的每一個參數與緩衝區都會有一個額外大小為 M。的維度。

combine_state_for_ensemble() 同時也會傳回 func,也就是 models 中的其中一個模型的功能版本。您無法直接執行 func(params, buffers, *args, **kwargs),您應該要使用 vmap(func, ...)(params, buffers, *args, **kwargs)

這裡提供一個如何在非常簡單的模型中進行整合的範例

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

fmodel, params, buffers = combine_state_for_ensemble(models)
output = vmap(fmodel, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)

警告

堆疊的所有模組都必須相同(除了其參數/緩衝區的值以外)。例如,它們應該處在相同的模式(訓練相對於評估)。

此 API 可能會變更,我們正在探索建立整合體的更佳方法,並且熱烈歡迎您提供改善建議。

警告

我們已將 functorch 整合到 PyTorch。作為整合的最後一步,functorch.combine_state_for_ensemble 已從 PyTorch 2.0 版開始進行棄用,且將在 PyTorch >= 2.3 的後續版本中刪除。請改用 torch.func.stack_module_state;請參閱 PyTorch 2.0 發行說明和/或 torch.func 移轉指南以取得更多詳細資訊 https://pytorch.dev.org.tw/docs/master/func.migrating.html

文件

取得 PyTorch 的全方位開發人員文件。

檢視文件

教學

取得初學者與進階開發人員的深入教學。

檢視教學課程

資源

找出開發資源,並取得問題解答。

檢視資源