快捷方式

DistributionalQValueActor

class torchrl.modules.tensordict_module.DistributionalQValueActor(*args, **kwargs)[來源]

Distributional DQN 執行者類別。

此類別在輸入模組後附加一個 QValueModule,以便使用動作值來選擇動作。

參數:

module (nn.Module) – 一個 torch.nn.Module,用於將輸入映射到輸出參數空間。 如果模組不是 torchrl.modules.DistributionalDQNnet 類型,DistributionalQValueActor 將確保對動作值張量沿著維度 -2 應用 log-softmax 運算。 可以透過關閉 make_log_softmax 關鍵字引數來停用此功能。

關鍵字引數:
  • in_keys (字串的可迭代物件, 選用) – 從輸入 tensordict 讀取並傳遞到模組的鍵。 如果它包含多個元素,則值將按照 in_keys 可迭代物件給定的順序傳遞。 預設為 ["observation"]

  • spec (TensorSpec, 選用) – 僅限關鍵字的引數。 輸出張量的規格。 如果模組輸出多個輸出張量,則 spec 表徵第一個輸出張量的空間。

  • safe (bool) – 僅限關鍵字的引數。 如果為 True,則會根據輸入 spec 檢查輸出的值。 由於探索策略或數值下溢/溢位問題,可能會發生超出範圍的採樣。 如果此值超出範圍,則會使用 TensorSpec.project 方法將其投影回所需的空間。 預設值為 False

  • var_nums (int, 選用) – 如果 action_space = "mult-one-hot",則此值表示每個動作元件的基數。

  • support (torch.Tensor) – 動作值的支援。

  • action_space (str, optional) – 動作空間。必須為 "one-hot""mult-one-hot""binary""categorical" 其中之一。此參數與 spec 互斥,因為 spec 會限制動作空間。

  • make_log_softmax (bool, optional) – 如果 True 且該模組不是 torchrl.modules.DistributionalDQNnet 類型,則會在動作值張量的 -2 維度上應用 log-softmax 運算。

  • action_value_key (str or tuple of str, optional) – 如果輸入模組是 tensordict.nn.TensorDictModuleBase 實例,則必須與其輸出鍵之一匹配。 否則,此字串表示輸出 tensordict 中動作值條目的名稱。

  • action_mask_key (str or tuple of str, optional) – 代表動作遮罩的輸入鍵。預設為 "None" (等同於沒有遮罩)。

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
>>> module = MLP(out_features=(nbins, 4), depth=2)
>>> # let us make sure that the output is a log-softmax
>>> module = TensorDictSequential(
...     TensorDictModule(module, ["observation"], ["action_value"]),
...     TensorDictModule(lambda x: x.log_softmax(-2), ["action_value"], ["action_value"]),
... )
>>> action_spec = OneHot(4)
>>> qvalue_actor = DistributionalQValueActor(
...     module=module,
...     spec=action_spec,
...     support=torch.arange(nbins))
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源