VmapModule¶
- class torchrl.modules.VmapModule(*args, **kwargs)[來源]¶
TensorDictModule 包裝器,用於在輸入上進行 vmap。
它旨在與接受比所提供資料少一個批次維度的模組一起使用。 透過使用此包裝器,可以隱藏批次維度並滿足包裝的模組。
- 參數:
module (TensorDictModuleBase) – 要進行 vmap 的模組。
vmap_dim (int, optional) – vmap 輸入和輸出維度。 如果未提供,則假設為 tensordict 的最後一個維度。
注意
由於 vmap 需要控制輸入的批次大小,因此此模組不支援已分派的引數
範例
>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) >>> sample_in = torch.ones((10,3,2)) >>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10]) >>> lam(sample_in) >>> vm = VmapModule(lam, 0) >>> vm(sample_in_td) >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()