QValueHook¶
- class torchrl.modules.QValueHook(action_space: str, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None)[原始碼]¶
Q 值策略的 Q 值掛鉤。
給定一個常規 nn.Module 的輸出,表示不同可用離散動作的值,QValueHook 會將這些值轉換為其 argmax 分量(即產生的貪婪動作)。
- 參數:
action_space (str) – 動作空間。必須是
"one-hot"
、"mult-one-hot"
、"binary"
或"categorical"
之一。var_nums (int, optional) – 如果
action_space = "mult-one-hot"
,則此值表示每個動作分量的基數。action_value_key (str 或 str 元組, optional) – 在掛鉤到 TensorDictModule 時使用。表示動作值的輸入鍵。預設為
"action_value"
。action_mask_key (str 或 str 的 tuple,選填) – 代表動作遮罩的輸入鍵。預設為
"None"
(相當於沒有遮罩)。out_keys (str 的 list 或 str 的 tuple,選填) – 在 TensorDictModule 上使用時使用。代表動作、動作值和所選動作值的輸出鍵。預設為
["action", "action_value", "chosen_action_value"]
。
範例
>>> import torch >>> from tensordict import TensorDict >>> from torch import nn >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> module = nn.Linear(4, 4) >>> hook = QValueHook("one_hot") >>> module.register_forward_hook(hook) >>> action_spec = OneHot(4) >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False)