快捷方式

ModelBasedEnvBase

torchrl.envs.ModelBasedEnvBase(*args, **kwargs)[原始碼]

用於基於模型的 RL sota 實現的基本環境。

MBRL 演算法模型的封裝器。它旨在為世界模型提供一個環境框架(包括但不限於觀察、獎勵、完成狀態和安全約束模型),並表現為一個經典的環境。

這是其他環境的基底類別,不應直接使用。

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Composite, Unbounded
>>> class MyMBEnv(ModelBasedEnvBase):
...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
...         self.observation_spec = Composite(
...             hidden_observation=Unbounded((4,))
...         )
...         self.state_spec = Composite(
...             hidden_observation=Unbounded((4,)),
...         )
...         self.action_spec = Unbounded((1,))
...         self.reward_spec = Unbounded((1,))
...
...     def _reset(self, tensordict: TensorDict) -> TensorDict:
...         tensordict = TensorDict({},
...             batch_size=self.batch_size,
...             device=self.device,
...         )
...         tensordict = tensordict.update(self.state_spec.rand())
...         tensordict = tensordict.update(self.observation_spec.rand())
...         return tensordict
>>> # This environment is used as follows:
>>> import torch.nn as nn
>>> from torchrl.modules import MLP, WorldModelWrapper
>>> world_model = WorldModelWrapper(
...     TensorDictModule(
...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
...         in_keys=["hidden_observation", "action"],
...         out_keys=["hidden_observation"],
...     ),
...     TensorDictModule(
...         nn.Linear(4, 1),
...         in_keys=["hidden_observation"],
...         out_keys=["reward"],
...     ),
... )
>>> env = MyMBEnv(world_model)
>>> tensordict = env.rollout(max_steps=10)
>>> print(tensordict)
TensorDict(
    fields={
        action: Tensor(torch.Size([10, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([10, 1]), dtype=torch.bool),
        hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32),
        next: LazyStackedTensorDict(
            fields={
                hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False),
        reward: Tensor(torch.Size([10, 1]), dtype=torch.float32)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
屬性
  • observation_spec (Composite):觀察的取樣規範;

  • action_spec (TensorSpec):動作的取樣規範;

  • reward_spec (TensorSpec):獎勵的取樣規範;

  • input_spec (Composite):輸入的取樣規範;

  • batch_size (torch.Size):env 使用的 batch_size。如果未設定,env 將接受所有 batch size 的 tensordict。

  • device (torch.device):env 輸入和輸出預期存在的設備

參數:
  • world_model (nn.Module) – 生成世界狀態及其對應獎勵的模型;

  • params (List[torch.Tensor], optional) – 世界模型的參數列表;

  • buffers (List[torch.Tensor], optional) – 世界模型的緩衝區列表;

  • device (torch.device, optional) – env 輸入和輸出預期存在的設備

  • dtype (torch.dtype, optional) – env 輸入和輸出的 dtype

  • batch_size (torch.Size, optional) – 實例中包含的環境數量

  • run_type_check (bool, optional) – 是否對 env 的步驟執行類型檢查

torchrl.envs.step(TensorDict -> TensorDict)

在環境中執行一步 (step)。

torchrl.envs.reset(TensorDict, optional -> TensorDict)

重置 (reset) 環境。

torchrl.envs.set_seed(int -> int)

設定環境的隨機種子 (seed)。

torchrl.envs.rand_step(TensorDict, optional -> TensorDict)

根據動作規範 (action spec) 執行隨機的步 (step)。

torchrl.envs.rollout(Callable, ... -> TensorDict)

使用給定的策略在環境中執行 rollout (如果沒有提供策略,則執行隨機步)。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學課程

取得適合初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您問題的解答

檢視資源