快捷方式

VmapModule

class torchrl.modules.VmapModule(*args, **kwargs)[來源]

TensorDictModule 包裝器,用於在輸入上進行 vmap。

它旨在與接受比所提供資料少一個批次維度的模組一起使用。 透過使用此包裝器,可以隱藏批次維度並滿足包裝的模組。

參數:
  • module (TensorDictModuleBase) – 要進行 vmap 的模組。

  • vmap_dim (int, optional) – vmap 輸入和輸出維度。 如果未提供,則假設為 tensordict 的最後一個維度。

注意

由於 vmap 需要控制輸入的批次大小,因此此模組不支援已分派的引數

範例

>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
>>> sample_in = torch.ones((10,3,2))
>>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10])
>>> lam(sample_in)
>>> vm = VmapModule(lam, 0)
>>> vm(sample_in_td)
>>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()
forward(tensordict)[來源]

定義每次呼叫時執行的計算。

應由所有子類別覆寫。

注意

雖然正向傳遞的配方需要在這個函式中定義,但應該在之後呼叫 Module 實例而不是這個,因為前者負責執行註冊的 hook,而後者會靜默地忽略它們。

文件

存取 PyTorch 的全面開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源