捷徑

QValueActor

class torchrl.modules.tensordict_module.QValueActor(*args, **kwargs)[來源]

Q 值行動者類別。

此類別將 QValueModule 附加在輸入模組之後,以便使用動作值來選擇動作。

參數:

module (nn.Module) – 用於將輸入映射到輸出參數空間的 torch.nn.Module。如果提供的類別與 tensordict.nn.TensorDictModuleBase 不相容,它將被包裝在 tensordict.nn.TensorDictModule 中,其中 in_keys 由以下關鍵字引數指示。

關鍵字引數:
  • in_keys (str 的可迭代物件, 選用) – 如果提供的類別與 tensordict.nn.TensorDictModuleBase 不相容,則此鍵清單指示需要將哪些觀測值傳遞給包裝模組以取得動作值。預設為 ["observation"]

  • spec (TensorSpec, 選用) – 僅限關鍵字的引數。輸出張量的規格。如果模組輸出多個輸出張量,則 spec 描述第一個輸出張量的空間。

  • safe (bool) – 僅限關鍵字的引數。如果 True,則會根據輸入規格檢查輸出的值。由於探索策略或數值上溢/下溢問題,可能會發生超出範圍的取樣。如果此值超出範圍,則使用 TensorSpec.project 方法將其投影回所需的空間。預設值為 False

  • action_space (str, 選用) – 動作空間。必須是 "one-hot""mult-one-hot""binary""categorical" 之一。此引數與 spec 互斥,因為 spec 限制了 action_space。

  • action_value_key (strstr 的元組, 選用) – 如果輸入模組是 tensordict.nn.TensorDictModuleBase 實例,則它必須符合其輸出鍵之一。否則,此字串表示輸出 tensordict 中動作值條目的名稱。

  • action_mask_key (strstr 的 tuple, 選填) – 代表動作遮罩的輸入鍵。預設為 "None" (相當於沒有遮罩)。

注意

不能傳入 out_keys。如果模組是 tensordict.nn.TensorDictModule 的實例,則 out_keys 會相應地更新。對於常規的 torch.nn.Module 實例,將使用三元組 ["action", action_value_key, "chosen_action_value"]

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> # with a regular nn.Module
>>> module = nn.Linear(4, 4)
>>> action_spec = OneHot(4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)
>>> # with a TensorDictModule
>>> td = TensorDict({'obs': torch.randn(5, 4)}, [5])
>>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"])
>>> action_spec = OneHot(4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得您的問題解答

查看資源