DistributionalQValueHook¶
- class torchrl.modules.DistributionalQValueHook(action_space: str, support: Tensor, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None)[source]¶
用於 Q 值策略的分佈式 Q 值掛鉤。
給定映射運算子的輸出,表示不同動作值區間的對數機率,DistributionalQValueHook 將使用提供的支持將這些值轉換為其 argmax 分量。
有關分佈式 DQN 的更多詳細資訊,請參閱「強化學習的分佈式視角」,https://arxiv.org/pdf/1707.06887.pdf
- 參數:
action_space (str) – 動作空間。必須是
"one-hot"
、"mult-one-hot"
、"binary"
或"categorical"
之一。action_value_key (str 或 str 元組, 選用) – 在 TensorDictModule 上掛鉤時使用。表示動作值的輸入鍵。預設為
"action_value"
。action_mask_key (str 或 str 元組, 選用) – 表示動作遮罩的輸入鍵。預設為
"None"
(等同於不遮罩)。support (torch.Tensor) – 動作值的支持。
var_nums (int, optional) – 如果
action_space = "mult-one-hot"
,此值代表每個動作組件的基數。
範例
>>> import torch >>> from tensordict import TensorDict >>> from torch import nn >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> nbins = 3 >>> class CustomDistributionalQval(nn.Module): ... def __init__(self): ... super().__init__() ... self.linear = nn.Linear(4, nbins*4) ... ... def forward(self, x): ... return self.linear(x).view(-1, nbins, 4).log_softmax(-2) ... >>> module = CustomDistributionalQval() >>> params = TensorDict.from_module(module) >>> action_spec = OneHot(4) >>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins)) >>> module.register_forward_hook(hook) >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) >>> with params.to_module(module): ... qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(torch.Size([5, 4]), dtype=torch.int64), action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32), observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, batch_size=torch.Size([5]), device=None, is_shared=False)