快捷方式

ActorCriticOperator

class torchrl.modules.tensordict_module.ActorCriticOperator(*args, **kwargs)[原始碼]

Actor-critic 運算子。

此類別將 actor 和價值模型包裝在一起,它們共享一個共同的觀察嵌入網路

../../_images/aafig-79381044a8773741ed8c83c5de90ab4def5c10b2.svg

注意

對於返回動作和狀態價值的類似類別 \(V(s)\),請參閱 ActorValueOperator

為了方便工作流程,此類別帶有一個 get_policy_operator() 方法,該方法將同時返回一個具有專用功能的獨立 TDModule。get_critic_operator 將返回父物件,因為該值是根據策略輸出計算的。

參數:
  • common_operator (TensorDictModule) – 一個常見的運算子,它讀取觀察並產生一個隱藏變數

  • policy_operator (TensorDictModule) – 一個策略運算子,它讀取隱藏變數並返回一個動作

  • value_operator (TensorDictModule) – 一個價值運算子,它讀取隱藏變數並返回一個值

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import ProbabilisticActor
>>> from torchrl.modules import  ValueOperator, TanhNormal, ActorCriticOperator, NormalParamExtractor, MLP
>>> module_hidden = torch.nn.Linear(4, 4)
>>> td_module_hidden = SafeModule(
...    module=module_hidden,
...    in_keys=["observation"],
...    out_keys=["hidden"],
...    )
>>> module_action = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
>>> module_action = TensorDictModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"])
>>> td_module_action = ProbabilisticActor(
...    module=module_action,
...    in_keys=["loc", "scale"],
...    out_keys=["action"],
...    distribution_class=TanhNormal,
...    return_log_prob=True,
...    )
>>> module_value = MLP(in_features=8, out_features=1, num_cells=[])
>>> td_module_value = ValueOperator(
...    module=module_value,
...    in_keys=["hidden", "action"],
...    out_keys=["state_action_value"],
...    )
>>> td_module = ActorCriticOperator(td_module_hidden, td_module_action, td_module_value)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> td_clone = td_module(td.clone())
>>> print(td_clone)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_policy_operator()(td.clone())
>>> print(td_clone)  # no value
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_critic_operator()(td.clone())
>>> print(td_clone)  # no action
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
get_critic_operator() TensorDictModuleWrapper[原始碼]

返回一個獨立的 critic 網路運算子,該運算子將狀態-動作對應到一個 critic 估計值。

get_policy_head() SafeSequential[原始碼]

返回策略 head。

get_value_head() SafeSequential[原始碼]

返回價值 head。

get_value_operator() TensorDictModuleWrapper[原始碼]

返回一個獨立的價值網路運算子,該運算子將觀察對應到一個價值估計值。

文件

獲取 PyTorch 的全面開發者文檔

查看文檔

教程

獲取針對初學者和高級開發者的深入教程

查看教程

資源

查找開發資源並獲得問題解答

查看資源