捷徑

DreamerModelLoss

class torchrl.objectives.DreamerModelLoss(*args, **kwargs)[原始碼]

Dreamer 模型損失。

計算 dreamer 世界模型的損失。損失由 RSSM 的先前與後驗之間的 KL 散度、重建觀測值的重建損失以及預測獎勵的獎勵損失組成。

參考文獻:https://arxiv.org/abs/1912.01603

參數:
  • world_model (TensorDictModule) – 世界模型。

  • lambda_kl (float, optional) – KL 散度損失的權重。預設值:1.0。

  • lambda_reco (float, optional) – 重建損失的權重。預設值:1.0。

  • lambda_reward (float, optional) – 獎勵損失的權重。預設值:1.0。

  • reco_loss (str, optional) – 重建損失。預設值:“l2”。

  • reward_loss (str, optional) – 獎勵損失。預設值:“l2”。

  • free_nats (int, optional) – 自由納特。預設值:3。

  • delayed_clamp (bool, optional) – 如果 True,則在平均後進行 KL 鉗位。如果為 False(預設),則首先將 KL 散度鉗位到自由納特值,然後再進行平均。

  • global_average (bool, optional) – 如果 True,則損失將在所有維度上平均。否則,將對所有非批次/時間維度執行總和,並對批次和時間進行平均。預設值:False。

forward(tensordict: TensorDict) Tensor[原始碼]

它旨在讀取輸入 TensorDict 並傳回另一個具有名為“loss*”的損失鍵的 tensordict。

然後,訓練器可以使用將損失拆分為其組成部分來記錄整個訓練過程中的各種損失值。輸出 tensordict 中存在的其他純量也將被記錄。

參數:

tensordict – 具有計算損失所需值的輸入 tensordict。

傳回:

一個沒有批次維度的新 tensordict,包含各種將被命名為 “loss*” 的損失純量。 重要的是,損失必須以此名稱返回,因為訓練器會在反向傳播之前讀取它們。

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源