快捷方式

torchrl.trainers 封包

trainer 封包提供公用程式來編寫可重複使用的訓練腳本。核心思想是使用一個訓練器,它實現了一個嵌套迴圈,其中外迴圈運行資料收集步驟,而內迴圈運行優化步驟。我們相信這適用於多種 RL 訓練方案,例如在策略、離策略、基於模型和無模型解決方案、離線 RL 等。更特殊的情況,例如元 RL 演算法,可能具有實質上不同的訓練方案。

trainer.train() 方法可以草繪如下

訓練器迴圈
        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

     >>> for batch in collector:
     ...     batch = self._process_batch_hook(batch)  # "batch_process"
     ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
     ...     self._pre_optim_hook()  # "pre_optim_steps"
     ...     for j in range(self.optim_steps_per_batch):
     ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
     ...         losses = self.loss_module(sub_batch)
     ...         self._post_loss_hook(sub_batch)  # "post_loss"
     ...         self.optimizer.step()
     ...         self.optimizer.zero_grad()
     ...         self._post_optim_hook()  # "post_optim"
     ...         self._post_optim_log(sub_batch)  # "post_optim_log"
     ...     self._post_steps_hook()  # "post_steps"
     ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

訓練器迴圈中可以使用 10 個鉤子:"batch_process""pre_optim_steps""process_optim_batch""post_loss""post_steps""post_optim""pre_steps_log""post_steps_log""post_optim_log""optimizer"。它們在應用它們的註釋中指出。鉤子可以分為 3 類:資料處理"batch_process""process_optim_batch")、日誌記錄"pre_steps_log""post_optim_log""post_steps_log")和 操作 鉤子("pre_optim_steps""post_loss""post_optim""post_steps")。

  • 資料處理 Hook 會更新一個資料的 tensordict。Hook 的 __call__ 方法應該接受一個 TensorDict 物件作為輸入,並根據某些策略更新它。此類 Hook 的範例包括 Replay Buffer 擴展 (ReplayBufferTrainer.extend)、資料正規化 (包括正規化常數更新)、資料子採樣 (:class:~torchrl.trainers.BatchSubSampler) 等。

  • 記錄 Hook 接受一批資料,以 TensorDict 的形式呈現,並在記錄器中寫入從該資料檢索的一些資訊。範例包括 Recorder Hook、獎勵記錄器 (LogReward) 等。Hook 應該回傳一個字典 (或 None 值),其中包含要記錄的資料。鍵 "log_pbar" 保留給布林值,用於指示記錄的值是否應顯示在訓練日誌上印出的進度條上。

  • 操作 Hook 是對模型、資料收集器、目標網路更新等執行特定操作的 Hook。例如,使用 UpdateWeights 同步收集器的權重,或使用 ReplayBufferTrainer.update_priority 更新 Replay Buffer 的優先順序,都是操作 Hook 的範例。它們與資料無關 (它們不需要 TensorDict 輸入),它們應該只是在每次迭代 (或每 N 次迭代) 執行一次。

TorchRL 提供的 Hook 通常繼承自一個共同的抽象類別 TrainerHookBase,並且都實作了三個基本方法:用於檢查點的 state_dictload_state_dict 方法,以及一個 register 方法,該方法在訓練器中以預設值註冊 Hook。此方法接受一個訓練器和一個模組名稱作為輸入。例如,以下記錄 Hook 每 10 次呼叫 "post_optim_log" 時執行

>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name):
...         trainer.register_module(self, "logging_hook")
...         trainer.register_op("post_optim_log", self)
...
...     def save_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         if self.counter % 10 == 0:
...             self.counter += 1
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out

檢查點

訓練器類別和 Hook 支援檢查點,可以使用 torchsnapshot 後端或常規的 torch 後端來實現。這可以透過全域變數 CKPT_BACKEND 來控制。

$ CKPT_BACKEND=torchsnapshot python script.py

CKPT_BACKEND 預設為 torch。torchsnapshot 相對於 pytorch 的優點在於它是一個更靈活的 API,它支援分散式檢查點,並且還允許使用者將磁碟上儲存的檔案中的張量載入到具有物理儲存體的張量 (pytorch 目前不支援)。例如,這允許從一個可能無法容納在記憶體中的 Replay Buffer 載入張量並將張量載入到其中。

在建構訓練器時,可以提供要寫入檢查點的路徑。使用 torchsnapshot 後端,預期的是目錄路徑,而 torch 後端預期的是檔案路徑 (通常是 .pt 檔案)。

>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)

Trainer.train() 方法可用於執行上述迴圈及其所有 Hook,但僅將 Trainer 類別用於其檢查點功能也是完全有效的用法。

訓練器和 Hook

BatchSubSampler(batch_size[, sub_traj_len, ...])

用於線上 RL sota 實現的資料子採樣器。

ClearCudaCache(interval)

以給定的間隔清除 CUDA 快取。

CountFramesLog(*args, **kwargs)

一個幀計數器 Hook。

LogReward([logname, log_pbar, reward_key])

獎勵記錄器 Hook。

OptimizerHook(optimizer[, loss_components])

為一個或多個損失元件新增最佳化器。

Recorder(*, record_interval, record_frames)

用於 Trainer 的記錄器 Hook。

ReplayBufferTrainer(replay_buffer[, ...])

Replay Buffer Hook 提供者。

RewardNormalizer([decay, scale, eps, ...])

獎勵正規化器 Hook。

SelectKeys(keys)

在 TensorDict 批次中選擇鍵。

Trainer(*args, **kwargs)

一個通用的訓練器類別。

TrainerHookBase()

torchrl 訓練器類別的抽象 Hook 類別。

UpdateWeights(collector, update_weights_interval)

一個收集器權重更新 Hook 類別。

建構器

make_collector_offpolicy(make_env, ...[, ...])

回傳一個用於 off-policy sota 實現的資料收集器。

make_collector_onpolicy(make_env, ...[, ...])

在 on-policy 設定中建立一個收集器。

make_dqn_loss(model, cfg)

建構 DQN 損失模組。

make_replay_buffer(device, cfg)

使用從 ReplayArgsConfig 建構的設定檔來建立回放緩衝區 (replay buffer)。

make_target_updater(cfg, loss_module)

建立一個目標網路權重更新物件。

make_trainer(collector, loss_module[, ...])

根據其組成部分建立 Trainer 實例。

parallel_env_constructor(cfg, **kwargs)

從使用適當的解析器建構子建構的 argparse.Namespace 返回一個並行環境。

sync_async_collector(env_fns, env_kwargs[, ...])

運行非同步收集器,每個收集器運行同步環境。

sync_sync_collector(env_fns, env_kwargs[, ...])

運行同步收集器,每個收集器運行同步環境。

transformed_env_constructor(cfg[, ...])

從使用適當的解析器建構子建構的 argparse.Namespace 返回一個環境建立器。

工具程式

correct_for_frame_skip(cfg)

通過將所有反映影格計數的參數除以 frame_skip,來修正輸入 frame_skip 的參數。

get_stats_random_rollout(cfg[, ...])

使用隨機 rollouts 從環境收集統計資訊(loc 和 scale)。

記錄器

Logger(exp_name, log_dir)

記錄器的範本。

csv.CSVLogger(exp_name[, log_dir, ...])

最小依賴性的 CSV 記錄器。

mlflow.MLFlowLogger(exp_name, tracking_uri)

mlflow 記錄器的包裝器。

tensorboard.TensorboardLogger(exp_name[, ...])

Tensoarboard 記錄器的包裝器。

wandb.WandbLogger(*args, **kwargs)

wandb 記錄器的包裝器。

get_logger(logger_type, logger_name, ...)

取得提供的 logger_type 的記錄器實例。

generate_exp_name(model_name, experiment_name)

使用 UUID 和目前日期為描述的實驗產生 ID (str)。

錄製工具程式

錄製工具程式的詳細資訊請參考此處

VideoRecorder(logger, tag[, in_keys, skip, ...])

影片錄製器轉換。

TensorDictRecorder(out_file_base[, ...])

TensorDict 記錄器。

PixelRenderTransform([out_keys, preproc, ...])

一個在父環境上呼叫 render 並在 tensordict 中註冊像素觀測的轉換。

文件

存取 PyTorch 的全面開發人員文件

檢視文件

教學

取得適合初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源