GAILLoss¶
- class torchrl.objectives.GAILLoss(*args, **kwargs)[來源]¶
生成對抗模仿學習 (GAIL) 損失的 TorchRL 實作。
在 “生成對抗模仿學習” <https://arxiv.org/pdf/1606.03476> 中提出
- 參數:
discriminator_network (TensorDictModule) – 隨機 actor
- 關鍵字引數:
use_grad_penalty (bool, optional) – 是否使用梯度懲罰。預設值:
False
。gp_lambda (float, optional) – 梯度懲罰 lambda。預設值:
10
。reduction (str, optional) – 指定要套用到輸出的縮減方式:
"none"
|"mean"
|"sum"
。"none"
:不會套用任何縮減,"mean"
:輸出的總和將除以輸出中的元素數量,"sum"
:將對輸出求和。預設值:"mean"
。