快捷方式

QValueHook

class torchrl.modules.QValueHook(action_space: str, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None)[原始碼]

Q 值策略的 Q 值掛鉤。

給定一個常規 nn.Module 的輸出,表示不同可用離散動作的值,QValueHook 會將這些值轉換為其 argmax 分量(即產生的貪婪動作)。

參數:
  • action_space (str) – 動作空間。必須是 "one-hot""mult-one-hot""binary""categorical" 之一。

  • var_nums (int, optional) – 如果 action_space = "mult-one-hot",則此值表示每個動作分量的基數。

  • action_value_key (strstr 元組, optional) – 在掛鉤到 TensorDictModule 時使用。表示動作值的輸入鍵。預設為 "action_value"

  • action_mask_key (strstr 的 tuple選填) – 代表動作遮罩的輸入鍵。預設為 "None" (相當於沒有遮罩)。

  • out_keys (str 的 liststr 的 tuple選填) – 在 TensorDictModule 上使用時使用。代表動作、動作值和所選動作值的輸出鍵。預設為 ["action", "action_value", "chosen_action_value"]

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> module = nn.Linear(4, 4)
>>> hook = QValueHook("one_hot")
>>> module.register_forward_hook(hook)
>>> action_spec = OneHot(4)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得適合初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源