DistributionalQValueActor¶
- class torchrl.modules.tensordict_module.DistributionalQValueActor(*args, **kwargs)[來源]¶
Distributional DQN 執行者類別。
此類別在輸入模組後附加一個
QValueModule
,以便使用動作值來選擇動作。- 參數:
module (nn.Module) – 一個
torch.nn.Module
,用於將輸入映射到輸出參數空間。 如果模組不是torchrl.modules.DistributionalDQNnet
類型,DistributionalQValueActor
將確保對動作值張量沿著維度-2
應用 log-softmax 運算。 可以透過關閉make_log_softmax
關鍵字引數來停用此功能。- 關鍵字引數:
in_keys (字串的可迭代物件, 選用) – 從輸入 tensordict 讀取並傳遞到模組的鍵。 如果它包含多個元素,則值將按照 in_keys 可迭代物件給定的順序傳遞。 預設為
["observation"]
。spec (TensorSpec, 選用) – 僅限關鍵字的引數。 輸出張量的規格。 如果模組輸出多個輸出張量,則 spec 表徵第一個輸出張量的空間。
safe (bool) – 僅限關鍵字的引數。 如果為
True
,則會根據輸入 spec 檢查輸出的值。 由於探索策略或數值下溢/溢位問題,可能會發生超出範圍的採樣。 如果此值超出範圍,則會使用TensorSpec.project
方法將其投影回所需的空間。 預設值為False
。var_nums (int, 選用) – 如果
action_space = "mult-one-hot"
,則此值表示每個動作元件的基數。support (torch.Tensor) – 動作值的支援。
action_space (str, optional) – 動作空間。必須為
"one-hot"
、"mult-one-hot"
、"binary"
或"categorical"
其中之一。此參數與spec
互斥,因為spec
會限制動作空間。make_log_softmax (bool, optional) – 如果
True
且該模組不是torchrl.modules.DistributionalDQNnet
類型,則會在動作值張量的 -2 維度上應用 log-softmax 運算。action_value_key (str or tuple of str, optional) – 如果輸入模組是
tensordict.nn.TensorDictModuleBase
實例,則必須與其輸出鍵之一匹配。 否則,此字串表示輸出 tensordict 中動作值條目的名稱。action_mask_key (str or tuple of str, optional) – 代表動作遮罩的輸入鍵。預設為
"None"
(等同於沒有遮罩)。
範例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule, TensorDictSequential >>> from torch import nn >>> from torchrl.data import OneHot >>> from torchrl.modules import DistributionalQValueActor, MLP >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> nbins = 3 >>> module = MLP(out_features=(nbins, 4), depth=2) >>> # let us make sure that the output is a log-softmax >>> module = TensorDictSequential( ... TensorDictModule(module, ["observation"], ["action_value"]), ... TensorDictModule(lambda x: x.log_softmax(-2), ["action_value"], ["action_value"]), ... ) >>> action_spec = OneHot(4) >>> qvalue_actor = DistributionalQValueActor( ... module=module, ... spec=action_spec, ... support=torch.arange(nbins)) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False)