QValueModule¶
- class torchrl.modules.tensordict_module.QValueModule(*args, **kwargs)[source]¶
用於 Q 值策略的 Q 值 TensorDictModule。
此模組根據給定的動作空間(one-hot、binary 或 categorical),將包含動作值的張量處理為其 argmax 組件(即產生的貪婪動作)。它適用於 tensordict 和常規張量。
- 參數:
action_space (str, optional) – 動作空間。必須是
"one-hot"
、"mult-one-hot"
、"binary"
或"categorical"
之一。此引數與spec
互斥,因為spec
會影響 action_space。action_value_key (str or tuple of str, optional) – 代表動作值的輸入鍵。預設為
"action_value"
。action_mask_key (str or tuple of str, optional) – 代表動作遮罩的輸入鍵。預設為
"None"
(相當於不遮罩)。out_keys (list of str or tuple of str, optional) – 代表動作、動作值和所選動作值的輸出鍵。預設為
["action", "action_value", "chosen_action_value"]
。var_nums (int, optional) – 如果
action_space = "mult-one-hot"
,則此值表示每個動作組件的基數。spec (TensorSpec, optional) – 如果提供,則為動作(和/或其他輸出)的規格。這與
action_space
互斥,因為規格會影響動作空間。safe (bool) – 如果
True
,則輸出值會根據輸入規格進行檢查。由於探索策略或數值上的過小/溢位問題,可能會發生超出範圍的取樣。如果此值超出範圍,它會使用TensorSpec.project
方法投影回所需的空間。預設值為False
。
- 回傳:
如果輸入是單個張量,則會回傳一個包含所選動作、值以及所選動作值的 triplet。如果提供了 tensordict,則它會以
out_keys
欄位指示的鍵更新這些條目。
範例
>>> from tensordict import TensorDict >>> action_space = "categorical" >>> action_value_key = "my_action_value" >>> actor = QValueModule(action_space, action_value_key=action_value_key) >>> # This module works with both tensordict and regular tensors: >>> value = torch.zeros(4) >>> value[-1] = 1 >>> actor(my_action_value=value) (tensor(3), tensor([0., 0., 0., 1.]), tensor([1.])) >>> actor(value) (tensor(3), tensor([0., 0., 0., 1.]), tensor([1.])) >>> actor(TensorDict({action_value_key: value}, [])) TensorDict( fields={ action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), my_action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)