ActorCriticOperator
- class torchrl.modules.tensordict_module.ActorCriticOperator(*args, **kwargs) [source]
Actor-critic operator.
This class wraps together an actor and a value model that share a common observation embedding network.
Note
For a similar class that returns an action and a state value \(V(s)\), see ActorValueOperator. To ease the workflow, this class comes with a get_policy_operator() method, which returns a standalone TDModule with the dedicated functionality. get_critic_operator() will return the parent object, since the value is computed based on the policy output.
- Parameters:
common_operator (TensorDictModule) – a common operator that reads observations and produces a hidden variable
policy_operator (TensorDictModule) – a policy operator that reads the hidden variable and returns an action
value_operator (TensorDictModule) – a value operator that reads the hidden variable and returns a value
Examples:
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import ProbabilisticActor
>>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamExtractor, MLP, SafeModule
>>> module_hidden = torch.nn.Linear(4, 4)
>>> td_module_hidden = SafeModule(
...     module=module_hidden,
...     in_keys=["observation"],
...     out_keys=["hidden"],
... )
>>> module_action = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
>>> module_action = TensorDictModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"])
>>> td_module_action = ProbabilisticActor(
...     module=module_action,
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=TanhNormal,
...     return_log_prob=True,
... )
>>> module_value = MLP(in_features=8, out_features=1, num_cells=[])
>>> td_module_value = ValueOperator(
...     module=module_value,
...     in_keys=["hidden", "action"],
...     out_keys=["state_action_value"],
... )
>>> td_module = ActorCriticOperator(td_module_hidden, td_module_action, td_module_value)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> td_clone = td_module(td.clone())
>>> print(td_clone)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_policy_operator()(td.clone())
>>> print(td_clone)  # no value
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_critic_operator()(td.clone())
>>> print(td_clone)  # the action is computed as well, since the critic builds on the policy output
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
- get_critic_operator() → TensorDictModuleWrapper [source]
Returns a standalone critic network operator that maps a state-action pair to a critic estimate.
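As the note above explains, the returned operator is the parent module itself, since the state-action value is computed from the policy output. A minimal sketch, reusing the td_module built in the example above (the identity check is an assumption based on the "returns the parent object" behavior described in the note):
>>> critic = td_module.get_critic_operator()
>>> critic is td_module  # the parent object is returned: the value needs the action
True
>>> td_out = critic(TensorDict({"observation": torch.randn(3, 4)}, [3,]))
>>> assert "state_action_value" in td_out.keys()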
- get_policy_head() → SafeSequential [source]
Returns the policy head.
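Unlike get_policy_operator(), the head alone does not include the common embedding module, so it reads the hidden variable rather than the observation. A minimal sketch, assuming the modules from the example above:
>>> policy_head = td_module.get_policy_head()
>>> td_out = policy_head(TensorDict({"hidden": torch.randn(3, 4)}, [3,]))
>>> assert "action" in td_out.keys()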
- get_value_head() → SafeSequential [source]
Returns the value head.
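The value head likewise skips the common embedding; with the modules from the example above it expects both the hidden variable and the action as inputs. A minimal sketch:
>>> value_head = td_module.get_value_head()
>>> td_out = value_head(TensorDict({"hidden": torch.randn(3, 4), "action": torch.randn(3, 4)}, [3,]))
>>> assert "state_action_value" in td_out.keys()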