ModelBasedEnvBase¶
- class torchrl.envs.ModelBasedEnvBase(*args, **kwargs)[source]¶
Basic environment for model-based RL sota-implementations.
Wrapper around the model of an MBRL algorithm. It is meant to give an environment framework to a world model (including, but not limited to, models of the observations, rewards, done states and safety constraints) and to behave as a classic environment.
This is a base class for other environments and it should not be used directly.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.data import Composite, Unbounded
>>> from torchrl.envs import ModelBasedEnvBase
>>> class MyMBEnv(ModelBasedEnvBase):
...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
...         self.observation_spec = Composite(
...             hidden_observation=Unbounded((4,))
...         )
...         self.state_spec = Composite(
...             hidden_observation=Unbounded((4,)),
...         )
...         self.action_spec = Unbounded((1,))
...         self.reward_spec = Unbounded((1,))
...
...     def _reset(self, tensordict: TensorDict) -> TensorDict:
...         tensordict = TensorDict({},
...             batch_size=self.batch_size,
...             device=self.device,
...         )
...         tensordict = tensordict.update(self.state_spec.rand())
...         tensordict = tensordict.update(self.observation_spec.rand())
...         return tensordict
>>> # This environment is used as follows:
>>> import torch.nn as nn
>>> from torchrl.modules import MLP, WorldModelWrapper
>>> world_model = WorldModelWrapper(
...     TensorDictModule(
...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
...         in_keys=["hidden_observation", "action"],
...         out_keys=["hidden_observation"],
...     ),
...     TensorDictModule(
...         nn.Linear(4, 1),
...         in_keys=["hidden_observation"],
...         out_keys=["reward"],
...     ),
... )
>>> env = MyMBEnv(world_model)
>>> tensordict = env.rollout(max_steps=10)
>>> print(tensordict)
TensorDict(
    fields={
        action: Tensor(torch.Size([10, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([10, 1]), dtype=torch.bool),
        hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32),
        next: LazyStackedTensorDict(
            fields={
                hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False),
        reward: Tensor(torch.Size([10, 1]), dtype=torch.float32)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
- Properties:
observation_spec (Composite): sampling spec of the observations;
action_spec (TensorSpec): sampling spec of the actions;
reward_spec (TensorSpec): sampling spec of the rewards;
input_spec (Composite): sampling spec of the inputs;
batch_size (torch.Size): batch size to be used by the env. If not set, the env accepts tensordicts of all batch sizes.
device (torch.device): device where the env inputs and outputs are expected to live
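For illustration, here is a small sketch of how these attributes can be read off an instance; it reuses the MyMBEnv and world_model objects defined in the example above, and the commented values are what one would expect under those definitions:

>>> # assumes MyMBEnv and world_model are defined as in the example above
>>> env = MyMBEnv(world_model)
>>> print(env.observation_spec)  # Composite holding the "hidden_observation" entry
>>> print(env.action_spec)       # unbounded spec of shape torch.Size([1])
>>> print(env.batch_size)        # torch.Size([]) here, since no batch_size was passed
>>> print(env.device)            # cpu by default in this sketch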
- Parameters:
world_model (nn.Module) – model that generates world states and their corresponding rewards;
params (List[torch.Tensor], optional) – list of parameters of the world model;
buffers (List[torch.Tensor], optional) – list of buffers of the world model;
device (torch.device, optional) – device where the env inputs and outputs are expected to live
dtype (torch.dtype, optional) – dtype of the env inputs and outputs
batch_size (torch.Size, optional) – number of environments contained in the instance
run_type_check (bool, optional) – whether type checks should be run on the steps of the env
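As a rough construction sketch (again reusing world_model and MyMBEnv from the example above; the keyword names simply mirror the parameter list, and the chosen values are arbitrary):

>>> # reuses world_model from the example above
>>> env = MyMBEnv(
...     world_model,
...     device="cpu",          # where env inputs/outputs are expected to live
...     dtype=torch.float32,   # dtype of the env inputs/outputs
... )
>>> td = env.rollout(max_steps=3)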
- torchrl.envs.step(TensorDict -> TensorDict)¶
Makes a step in the environment.
- torchrl.envs.reset(TensorDict, optional -> TensorDict)¶
Resets the environment.
- torchrl.envs.set_seed(int -> int)¶
Sets the seed of the environment.
- torchrl.envs.rand_step(TensorDict, optional -> TensorDict)¶
Performs a random step in the environment given the action spec.
- torchrl.envs.rollout(Callable, ... -> TensorDict)¶
Performs a rollout in the environment with the given policy (or performs random steps if no policy is provided); see the usage sketch below.
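A minimal usage sketch of these calls, assuming the env built in the example above (the "action" key name comes from that example; the comments describe the expected behaviour, not guaranteed output):

>>> td = env.reset()                       # fresh state/observation sampled by _reset
>>> env.set_seed(0)                        # seeding; effect depends on the concrete env
>>> td["action"] = env.action_spec.rand()  # fill in an action matching the action spec
>>> td = env.step(td)                      # world model writes next obs/reward under "next"
>>> td = env.rand_step()                   # same as step, with a randomly drawn action
>>> rollout = env.rollout(max_steps=5)     # random steps here, since no policy is passed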