快捷方式

DistributionalQValueHook

class torchrl.modules.DistributionalQValueHook(action_space: str, support: Tensor, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None)[source]

用於 Q 值策略的分佈式 Q 值掛鉤。

給定映射運算子的輸出,表示不同動作值區間的對數機率,DistributionalQValueHook 將使用提供的支持將這些值轉換為其 argmax 分量。

有關分佈式 DQN 的更多詳細資訊,請參閱「強化學習的分佈式視角」,https://arxiv.org/pdf/1707.06887.pdf

參數:
  • action_space (str) – 動作空間。必須是 "one-hot""mult-one-hot""binary""categorical" 之一。

  • action_value_key (strstr 元組, 選用) – 在 TensorDictModule 上掛鉤時使用。表示動作值的輸入鍵。預設為 "action_value"

  • action_mask_key (strstr 元組, 選用) – 表示動作遮罩的輸入鍵。預設為 "None" (等同於不遮罩)。

  • support (torch.Tensor) – 動作值的支持。

  • var_nums (int, optional) – 如果 action_space = "mult-one-hot",此值代表每個動作組件的基數。

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
>>> class CustomDistributionalQval(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(4, nbins*4)
...
...     def forward(self, x):
...         return self.linear(x).view(-1, nbins, 4).log_softmax(-2)
...
>>> module = CustomDistributionalQval()
>>> params = TensorDict.from_module(module)
>>> action_spec = OneHot(4)
>>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins))
>>> module.register_forward_hook(hook)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> with params.to_module(module):
...     qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(torch.Size([5, 4]), dtype=torch.int64),
        action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32),
        observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深度教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源