快捷方式

GAILLoss

class torchrl.objectives.GAILLoss(*args, **kwargs)[來源]

生成對抗模仿學習 (GAIL) 損失的 TorchRL 實作。

“生成對抗模仿學習” <https://arxiv.org/pdf/1606.03476> 中提出

參數:

discriminator_network (TensorDictModule) – 隨機 actor

關鍵字引數:
  • use_grad_penalty (bool, optional) – 是否使用梯度懲罰。預設值: False

  • gp_lambda (float, optional) – 梯度懲罰 lambda。預設值: 10

  • reduction (str, optional) – 指定要套用到輸出的縮減方式: "none" | "mean" | "sum""none":不會套用任何縮減, "mean":輸出的總和將除以輸出中的元素數量, "sum":將對輸出求和。預設值: "mean"

forward(tensordict: TensorDictBase = None) TensorDictBase[來源]

forward 方法。

如果 use_grad_penalty 設定為 True,則計算辨別器損失和梯度懲罰。 如果 use_grad_penalty 設定為 True,則也會傳回分離的梯度懲罰損失以供記錄。 若要查看輸入 tensordict 中預期的金鑰以及預期作為輸出的金鑰,請查看類別的 “in_keys”“out_keys” 屬性。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源