QValueActor¶
- class torchrl.modules.tensordict_module.QValueActor(*args, **kwargs)[來源]¶
Q 值行動者類別。
此類別將
QValueModule
附加在輸入模組之後,以便使用動作值來選擇動作。- 參數:
module (nn.Module) – 用於將輸入映射到輸出參數空間的
torch.nn.Module
。如果提供的類別與tensordict.nn.TensorDictModuleBase
不相容,它將被包裝在tensordict.nn.TensorDictModule
中,其中in_keys
由以下關鍵字引數指示。- 關鍵字引數:
in_keys (str 的可迭代物件, 選用) – 如果提供的類別與
tensordict.nn.TensorDictModuleBase
不相容,則此鍵清單指示需要將哪些觀測值傳遞給包裝模組以取得動作值。預設為["observation"]
。spec (TensorSpec, 選用) – 僅限關鍵字的引數。輸出張量的規格。如果模組輸出多個輸出張量,則 spec 描述第一個輸出張量的空間。
safe (bool) – 僅限關鍵字的引數。如果
True
,則會根據輸入規格檢查輸出的值。由於探索策略或數值上溢/下溢問題,可能會發生超出範圍的取樣。如果此值超出範圍,則使用TensorSpec.project
方法將其投影回所需的空間。預設值為False
。action_space (str, 選用) – 動作空間。必須是
"one-hot"
、"mult-one-hot"
、"binary"
或"categorical"
之一。此引數與spec
互斥,因為spec
限制了 action_space。action_value_key (str 或 str 的元組, 選用) – 如果輸入模組是
tensordict.nn.TensorDictModuleBase
實例,則它必須符合其輸出鍵之一。否則,此字串表示輸出 tensordict 中動作值條目的名稱。action_mask_key (str 或 str 的 tuple, 選填) – 代表動作遮罩的輸入鍵。預設為
"None"
(相當於沒有遮罩)。
注意
不能傳入
out_keys
。如果模組是tensordict.nn.TensorDictModule
的實例,則 out_keys 會相應地更新。對於常規的torch.nn.Module
實例,將使用三元組["action", action_value_key, "chosen_action_value"]
。範例
>>> import torch >>> from tensordict import TensorDict >>> from torch import nn >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import QValueActor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> # with a regular nn.Module >>> module = nn.Linear(4, 4) >>> action_spec = OneHot(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False) >>> # with a TensorDictModule >>> td = TensorDict({'obs': torch.randn(5, 4)}, [5]) >>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"]) >>> action_spec = OneHot(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False)