快捷方式

Trainer

class torchrl.trainers.Trainer(*args, **kwargs)[原始碼]

通用的 Trainer 類別。

Trainer 負責收集資料並訓練模型。為了使類別盡可能地通用,Trainer 不會建構任何特定操作:它們都必須在訓練迴圈中的特定點掛鉤。

若要建構 Trainer,需要一個可迭代的資料來源 (collector)、一個損失模組和一個最佳化器。

參數:
  • collector (Sequence[TensorDictBase]) – 一個可迭代物件,以 TensorDict 形式傳回形狀為 [batch x time steps] 的資料批次。

  • total_frames (int) – 訓練期間要收集的總幀數。

  • loss_module (LossModule) – 一個模組,讀取 TensorDict 批次(可能從重播緩衝區取樣),並傳回一個損失 TensorDict,其中每個鍵指向不同的損失元件。

  • optimizer (optim.Optimizer) – 一個訓練模型參數的最佳化器。

  • logger (Logger, optional) – 一個將處理記錄的 Logger。

  • optim_steps_per_batch (int) – 每次資料收集的最佳化步驟數。一個 trainer 的運作方式如下:一個主迴圈收集資料批次(epoch 迴圈),而一個子迴圈(訓練迴圈)在兩次資料收集之間執行模型更新。

  • clip_grad_norm (bool, optional) – 如果為 True,梯度將根據模型參數的總範數進行裁剪。如果為 False,所有偏導數將被鉗制為 (-clip_norm, clip_norm)。預設值為 True

  • clip_norm (Number, optional) – 用於裁剪梯度的值。預設值為 None (無裁剪範數)。

  • progress_bar (bool, optional) – 若為 True,將會使用 tqdm 顯示進度條。如果沒有安裝 tqdm,此選項將不會有任何作用。預設值為 True

  • seed (int, optional) – 用於 collector、pytorch 和 numpy 的 seed。預設值為 None

  • save_trainer_interval (int, optional) – Trainer 儲存到磁碟的頻率,以 frame 數計算。預設值為 10000。

  • log_interval (int, optional) – 數值被記錄的頻率,以 frame 數計算。預設值為 10000。

  • save_trainer_file (path, optional) – 儲存 trainer 的路徑。預設值為 None (不儲存)。

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得您問題的解答

檢視資源