捷徑

ActorCriticWrapper

class torchrl.modules.tensordict_module.ActorCriticWrapper(*args, **kwargs)[來源]

沒有共同模組的 Actor-value 運算子。

此類別將一個 actor 和一個 value 模型包裝在一起,它們不共享一個共同的觀察嵌入網路。

../../_images/aafig-5b1c51d6da7f2229a6c42592c838f793bf136146.svg

為了方便工作流程,此類別提供 get_policy_operator() 和 get_value_operator() 方法,它們都將傳回一個具有專用功能的獨立 TDModule。

參數:
  • policy_operator (TensorDictModule) – 一個 policy 運算子,讀取隱藏變數並傳回一個動作。

  • value_operator (TensorDictModule) – 一個 value 運算子,讀取隱藏變數並傳回一個值。

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import (
...      ActorCriticWrapper,
...      ProbabilisticActor,
...      NormalParamExtractor,
...      TanhNormal,
...      ValueOperator,
...  )
>>> action_module = TensorDictModule(
...        nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
...        in_keys=["observation"],
...        out_keys=["loc", "scale"],
...    )
>>> td_module_action = ProbabilisticActor(
...    module=action_module,
...    in_keys=["loc", "scale"],
...    distribution_class=TanhNormal,
...    return_log_prob=True,
...    )
>>> module_value = torch.nn.Linear(4, 1)
>>> td_module_value = ValueOperator(
...    module=module_value,
...    in_keys=["observation"],
...    )
>>> td_module = ActorCriticWrapper(td_module_action, td_module_value)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> td_clone = td_module(td.clone())
>>> print(td_clone)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_policy_operator()(td.clone())
>>> print(td_clone)  # no value
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_value_operator()(td.clone())
>>> print(td_clone)  # no action
TensorDict(
    fields={
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
get_policy_head() SafeSequential

傳回一個獨立的 policy 運算子,它將觀察對應到一個動作。

get_policy_operator() SafeSequential[來源]

傳回一個獨立的 policy 運算子,它將觀察對應到一個動作。

get_value_head() SafeSequential

傳回一個獨立的 value 網路運算子,它將觀察對應到一個價值估計。

get_value_operator() SafeSequential[原始碼]

傳回一個獨立的 value 網路運算子,它將觀察對應到一個價值估計。

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並取得問題解答

查看資源