functorch.combine_state_for_ensemble¶
-
functorch.
combine_state_for_ensemble
(models) → func, params, buffers[source]¶ 準備一份將
vmap()
與torch.nn.Modules
進行整合使用的清單。給定一個相同類別的
M
nn.Modules
清單,將它們的所有參數與緩衝區堆疊,以產生params
與buffers
。結果中的每一個參數與緩衝區都會有一個額外大小為M
。的維度。combine_state_for_ensemble()
同時也會傳回func
,也就是models
中的其中一個模型的功能版本。您無法直接執行func(params, buffers, *args, **kwargs)
,您應該要使用vmap(func, ...)(params, buffers, *args, **kwargs)
這裡提供一個如何在非常簡單的模型中進行整合的範例
num_models = 5 batch_size = 64 in_features, out_features = 3, 3 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] data = torch.randn(batch_size, 3) fmodel, params, buffers = combine_state_for_ensemble(models) output = vmap(fmodel, (0, 0, None))(params, buffers, data) assert output.shape == (num_models, batch_size, out_features)
警告
堆疊的所有模組都必須相同(除了其參數/緩衝區的值以外)。例如,它們應該處在相同的模式(訓練相對於評估)。
此 API 可能會變更,我們正在探索建立整合體的更佳方法,並且熱烈歡迎您提供改善建議。
警告
我們已將 functorch 整合到 PyTorch。作為整合的最後一步,functorch.combine_state_for_ensemble 已從 PyTorch 2.0 版開始進行棄用,且將在 PyTorch >= 2.3 的後續版本中刪除。請改用 torch.func.stack_module_state;請參閱 PyTorch 2.0 發行說明和/或 torch.func 移轉指南以取得更多詳細資訊 https://pytorch.dev.org.tw/docs/master/func.migrating.html