DreamerModelLoss¶
- class torchrl.objectives.DreamerModelLoss(*args, **kwargs)[原始碼]¶
Dreamer 模型損失。
計算 dreamer 世界模型的損失。損失由 RSSM 的先前與後驗之間的 KL 散度、重建觀測值的重建損失以及預測獎勵的獎勵損失組成。
參考文獻:https://arxiv.org/abs/1912.01603。
- 參數:
world_model (TensorDictModule) – 世界模型。
lambda_kl (float, optional) – KL 散度損失的權重。預設值:1.0。
lambda_reco (float, optional) – 重建損失的權重。預設值:1.0。
lambda_reward (float, optional) – 獎勵損失的權重。預設值:1.0。
reco_loss (str, optional) – 重建損失。預設值:“l2”。
reward_loss (str, optional) – 獎勵損失。預設值:“l2”。
free_nats (int, optional) – 自由納特。預設值:3。
delayed_clamp (bool, optional) – 如果
True
,則在平均後進行 KL 鉗位。如果為 False(預設),則首先將 KL 散度鉗位到自由納特值,然後再進行平均。global_average (bool, optional) – 如果
True
,則損失將在所有維度上平均。否則,將對所有非批次/時間維度執行總和,並對批次和時間進行平均。預設值:False。