捷徑

DistributionalDQNLoss

class torchrl.objectives.DistributionalDQNLoss(*args, **kwargs)[source]

一個 distributional DQN 損失類別。

Distributional DQN 使用一個價值網路,該網路輸出一系列離散折扣回報支援上的價值分佈(與常規 DQN 不同,後者的價值網路輸出折扣回報的單一點預測)。

有關 Distributional DQN 的更多詳細資訊,請參閱「強化學習的分佈式視角」,https://arxiv.org/pdf/1707.06887.pdf

參數:
  • value_network (DistributionalQValueActornn.Module) – distributional Q 價值運算子。

  • gamma (純量) –

    用於回報計算的折扣因子。 .. note

    Unlike :class:`DQNLoss`, this class does not currently support
    custom value functions. The next value estimation is always
    bootstrapped.
    

  • delay_value (bool) – 是否將價值網路複製到一個新的目標價值網路中,以建立雙重 DQN

  • priority_key (str, optional) – [已棄用,請改用 .set_keys(priority_key=priority_key)] 假設優先順序儲存在新增至此 ReplayBuffer 的 TensorDict 中的鍵。這用於當取樣器屬於 PrioritizedSampler 類型時。預設值為 "td_error"

  • reduction (str, optional) – 指定要應用於輸出的縮減:"none" | "mean" | "sum""none":不應用任何縮減,"mean":輸出總和將除以輸出中的元素數,"sum":將對輸出求和。預設值:"mean"

forward(input_tensordict: TensorDictBase) TensorDict[原始碼]

它被設計用來讀取一個輸入 TensorDict,並返回另一個帶有 "loss*" 命名之損失鍵的 tensordict。

將損失分解成其組成部分,然後可被訓練器用來記錄整個訓練過程中的各種損失值。輸出 tensordict 中存在的其他純量也會被記錄。

參數:

tensordict – 一個輸入 tensordict,包含計算損失所需的值。

回傳值:

一個新的 tensordict,沒有批次維度,包含各種損失純量,這些純量將被命名為 "loss*"。損失必須以這個名稱回傳,因為它們會在反向傳播之前被訓練器讀取。

make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)[原始碼]

Value-function 建構子。

如果想要非預設的值函數,則必須使用此方法建立。

參數:
  • value_type (ValueEstimators) – 一個 ValueEstimators 列舉類型,指示要使用的值函數。如果未提供,將使用儲存在 default_value_estimator 屬性中的預設值。產生的值估計器類別將在 self.value_type 中註冊,以便將來進行改進。

  • **hyperparams – 用於值函數的超參數。如果未提供,將使用 default_value_kwargs() 指示的值。

範例

>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)

文件

訪問 PyTorch 的完整開發者文件

查看文件

教學課程

獲取針對初學者和高級開發人員的深入教學課程

查看教學課程

資源

查找開發資源並獲得問題解答

查看資源