捷徑

ReplayBufferTrainer

class torchrl.trainers.ReplayBufferTrainer(replay_buffer: TensorDictReplayBuffer, batch_size: Optional[int] = None, memmap: bool = False, device: Union[device, str, int] = 'cpu', flatten_tensordicts: bool = False, max_dims: Optional[Sequence[int]] = None)[原始碼]

重播緩衝區 Hook 提供者。

參數:
  • replay_buffer (TensorDictReplayBuffer) – 要使用的重播緩衝區。

  • batch_size (int, optional) – 從最新集合或重播緩衝區採樣資料時的批次大小。 如果未提供,將使用重播緩衝區批次大小(對於未變更的批次大小,這是首選選項)。

  • memmap (bool, optional) – 如果 True,則會建立一個 memmap tensordict。 預設值為 False

  • device (device, optional) – 樣本必須放置的裝置。 預設值為 cpu

  • flatten_tensordicts (bool, optional) – 如果 True,tensordicts 在傳遞到重播緩衝區之前會被扁平化 (或等效地使用從收集器獲得的有效遮罩進行遮罩)。 否則,除了填充之外,不會進行任何轉換 (請參閱下面的 max_dims 參數)。預設為 False

  • max_dims (sequence of int, optional) – 如果 flatten_tensordicts 設定為 False,這將是一個列表,其長度等於所提供 tensordicts 的 batch_size,代表每個 tensordict 的最大尺寸。 如果提供此列表,則將使用此列表中的尺寸來填充 tensordict,並使其形狀在傳遞到重播緩衝區之前匹配。 如果沒有最大值,則應提供 -1 值。

範例

>>> rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N)
>>> trainer.register_op("batch_process", rb_trainer.extend)
>>> trainer.register_op("process_optim_batch", rb_trainer.sample)
>>> trainer.register_op("post_loss", rb_trainer.update_priority)
register(trainer: Trainer, name: str = 'replay_buffer')[source]

在訓練器中以預設位置註冊 hook。

參數:
  • trainer (Trainer) – 必須註冊 hook 的訓練器。

  • name (str) – hook 的名稱。

注意

若要在預設位置以外的其他位置註冊 hook,請使用 register_op()

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適用於初學者和高級開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得您的問題解答

檢視資源