
BatchSizeTransform

class torchrl.envs.transforms.BatchSizeTransform(*, batch_size: Optional[Size] = None, reshape_fn: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, reset_func: Optional[Callable[[TensorDictBase, TensorDictBase], TensorDictBase]] = None, env_kwarg: bool = False)

A transform that modifies the batch size of an environment.

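This transform has two distinct uses. It can set the batch size of a non-batch-locked (e.g., stateless) environment so that data can be collected from it with a data collector. It can also modify the batch size of an environment (e.g., squeeze, unsqueeze or reshape it).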

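This transform changes the environment batch size to match the one provided. It expects the batch size of the parent environment to be expandable to the provided batch size.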
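The "expandable" requirement corresponds to the broadcasting rule used by `torch.Tensor.expand`. As a minimal pure-Python sketch, assuming the transform applies that same rule, the hypothetical helper `is_expandable` below (not part of TorchRL) checks whether a parent batch size can be expanded to a requested one:

```python
from typing import Tuple

Shape = Tuple[int, ...]


def is_expandable(parent: Shape, target: Shape) -> bool:
    """Hypothetical helper: can batch size `parent` be expanded to `target`?

    Mirrors the torch.Tensor.expand rule: dimensions are aligned from the
    right, each existing dimension must equal the target one or be 1, and
    new dimensions may only be added on the left.
    """
    if len(parent) > len(target):
        return False
    # Leading target dimensions are new and always allowed.
    offset = len(target) - len(parent)
    for p, t in zip(parent, target[offset:]):
        if p != t and p != 1:
            return False
    return True


print(is_expandable((), (5,)))      # True: an unbatched env can take any batch size
print(is_expandable((1,), (5,)))    # True: a size-1 dimension can be expanded
print(is_expandable((3,), (5,)))    # False: 3 cannot become 5
print(is_expandable((5,), (2, 5)))  # True: a new leading dimension is added
```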

Keyword Arguments:
  • batch_size (torch.Size or equivalent, optional) – the new batch size of the environment. Mutually exclusive with reshape_fn.

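  • reshape_fn (callable, optional) –

    a callable that modifies the environment batch size. Mutually exclusive with batch_size.

    Note

    Currently, transformations involving reshape, flatten, unflatten, squeeze and unsqueeze are supported. If you need another reshaping operation, please submit a feature request on the TorchRL GitHub repository.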

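  • reset_func (callable, optional) – a function that produces a reset tensordict. Its signature must match Callable[[TensorDictBase, TensorDictBase], TensorDictBase], where the first input argument is the optional tensordict passed to the environment when reset() is called, and the second is the output of TransformedEnv.base_env.reset. If env_kwarg=True, it can also accept an optional env keyword argument.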

  • env_kwarg (bool, optional) – if True, reset_func must accept an env keyword argument. Defaults to False. The env passed is the environment together with its transforms.
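The reshaping operations supported by reshape_fn change the batch size in predictable ways. The sketch below tracks only shapes, in pure Python. The helpers are hypothetical and merely mirror the semantics of the PyTorch operations of the same name, so that a reshape_fn such as `lambda data: data.reshape(1, 1)` can be reasoned about without running an environment:

```python
from math import prod
from typing import Sequence, Tuple

Shape = Tuple[int, ...]

# Hypothetical shape-only mirrors of the PyTorch operations of the same name.
# They operate on batch-size tuples and never touch any data.


def reshape(shape: Shape, new_shape: Sequence[int]) -> Shape:
    # The number of elements must be preserved (no -1 inference here).
    if prod(shape) != prod(new_shape):
        raise ValueError(f"cannot reshape {shape} into {tuple(new_shape)}")
    return tuple(new_shape)


def flatten(shape: Shape, start_dim: int = 0, end_dim: int = -1) -> Shape:
    # Merge dims start_dim..end_dim (inclusive) into one; needs a non-empty shape.
    start = start_dim % len(shape)
    end = end_dim % len(shape)
    return shape[:start] + (prod(shape[start:end + 1]),) + shape[end + 1:]


def unflatten(shape: Shape, dim: int, sizes: Sequence[int]) -> Shape:
    # Split dimension `dim` into `sizes`, whose product must match it.
    dim = dim % len(shape)
    if prod(sizes) != shape[dim]:
        raise ValueError(f"cannot unflatten {shape[dim]} into {tuple(sizes)}")
    return shape[:dim] + tuple(sizes) + shape[dim + 1:]


def squeeze(shape: Shape, dim: int) -> Shape:
    # As with torch.squeeze(dim), the dimension is only removed if it has size 1.
    dim = dim % len(shape)
    return (shape[:dim] + shape[dim + 1:]) if shape[dim] == 1 else shape


def unsqueeze(shape: Shape, dim: int) -> Shape:
    # A negative dim indexes the *output* shape, as in torch.unsqueeze.
    if dim < 0:
        dim += len(shape) + 1
    return shape[:dim] + (1,) + shape[dim:]


print(reshape((), (1, 1)))          # (1, 1), like reshape_fn=lambda data: data.reshape(1, 1)
print(flatten((2, 3, 4), 1))        # (2, 12)
print(unflatten((20,), 0, (4, 5)))  # (4, 5)
print(squeeze((1, 5), 0))           # (5,)
print(unsqueeze((5,), -1))          # (5, 1)
```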

Examples

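>>> # Changing the batch-size with a function
>>> from torchrl.envs import GymEnv, TransformedEnv
>>> from torchrl.envs.transforms import BatchSizeTransform
>>> base_env = GymEnv("CartPole-v1")  # unbatched env: batch size []
>>> # Reshape the env batch size from [] to [1, 1]
>>> env = TransformedEnv(base_env, BatchSizeTransform(reshape_fn=lambda data: data.reshape(1, 1)))
>>> env.rollout(4)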
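>>> # Setting the shape of a stateless environment
>>> from typing import Optional
>>> import torch
>>> from tensordict import TensorDictBase
>>> from torchrl.data import Composite, Unbounded
>>> from torchrl.envs import EnvBase, TransformedEnv
>>> from torchrl.envs.transforms import BatchSizeTransform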
>>> class MyEnv(EnvBase):
...     batch_locked = False
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = Composite(observation=Unbounded(3))
...         self.reward_spec = Unbounded(1)
...         self.action_spec = Unbounded(1)
...
...     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
...         tensordict_batch_size = tensordict.batch_size if tensordict is not None else torch.Size([])
...         result = self.observation_spec.rand(tensordict_batch_size)
...         result.update(self.full_done_spec.zero(tensordict_batch_size))
...         return result
...
...     def _step(
...         self,
...         tensordict: TensorDictBase,
...     ) -> TensorDictBase:
...         result = self.observation_spec.rand(tensordict.batch_size)
...         result.update(self.full_done_spec.zero(tensordict.batch_size))
...         result.update(self.full_reward_spec.zero(tensordict.batch_size))
...         return result
...
...     def _set_seed(self, seed: Optional[int]):
...         pass
...
>>> # batch_size is keyword-only (see the signature above)
>>> env = TransformedEnv(MyEnv(), BatchSizeTransform(batch_size=[5]))
>>> assert env.batch_size == torch.Size([5])
>>> assert env.rollout(10).shape == torch.Size([5, 10])

reset_func can create a tensordict with the desired batch size, allowing for fine-grained reset calls:

>>> def reset_func(tensordict, tensordict_reset, env):
...     result = env.observation_spec.rand()
...     result.update(env.full_done_spec.zero())
...     assert result.batch_size != torch.Size([])
...     return result
>>> env = TransformedEnv(MyEnv(), BatchSizeTransform(batch_size=[5], reset_func=reset_func, env_kwarg=True))
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([5, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5, 2]),
    device=None,
    is_shared=False)
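The way reset_func is called depending on env_kwarg can be summarised with a small, framework-free sketch. Plain dicts stand in for tensordicts, and `call_reset_func`, `FakeEnv` and `my_reset_func` are hypothetical names; the sketch only illustrates the calling convention documented in the keyword arguments and is not TorchRL's implementation:

```python
from typing import Any, Callable, Optional

# Plain dicts stand in for TensorDictBase in this framework-free sketch.
TD = dict


def call_reset_func(
    reset_func: Callable[..., TD],
    tensordict: Optional[TD],
    tensordict_reset: TD,
    env: Any,
    env_kwarg: bool,
) -> TD:
    """Hypothetical helper showing the documented calling convention.

    `tensordict` is the optional input given to env.reset(), and
    `tensordict_reset` is the output of base_env.reset.
    """
    if env_kwarg:
        # With env_kwarg=True, the env (with its transforms) is passed as `env=`.
        return reset_func(tensordict, tensordict_reset, env=env)
    return reset_func(tensordict, tensordict_reset)


class FakeEnv:
    """Hypothetical stand-in exposing only a batch size."""

    batch_size = (5,)


def my_reset_func(tensordict, tensordict_reset, env=None):
    # Copy the base reset output; record the env batch size if an env is given.
    result = dict(tensordict_reset)
    if env is not None:
        result["batch_size"] = env.batch_size
    return result


print(call_reset_func(my_reset_func, None, {"done": False}, FakeEnv(), env_kwarg=True))
# {'done': False, 'batch_size': (5,)}
print(call_reset_func(my_reset_func, None, {"done": False}, FakeEnv(), env_kwarg=False))
# {'done': False}
```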

This transform can be used to deploy non-batch-locked environments within data collectors:

>>> from torchrl.collectors import SyncDataCollector
>>> collector = SyncDataCollector(env, lambda td: env.rand_action(td), frames_per_batch=10, total_frames=-1)
>>> for data in collector:
...     print(data)
...     break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False),
        done: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([5, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5, 2]),
    device=None,
    is_shared=False)
>>> collector.shutdown()
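The [5, 2] batch size of the collected data above follows from simple arithmetic: the collector spreads frames_per_batch frames over every sub-environment, so each batch holds frames_per_batch divided by the number of sub-environments (here 10 / 5 = 2 steps). The hypothetical helper below sketches that calculation in pure Python, assuming frames_per_batch is a multiple of the number of sub-environments; how the real collector handles other values is not covered here:

```python
from math import prod
from typing import Sequence, Tuple


def collected_batch_shape(env_batch_size: Sequence[int], frames_per_batch: int) -> Tuple[int, ...]:
    """Hypothetical helper: shape of one collected batch (env dims, then time)."""
    n_envs = prod(env_batch_size)  # number of sub-environments (numel of the batch size)
    if frames_per_batch % n_envs:
        # Simplification for this sketch only.
        raise ValueError("frames_per_batch must be a multiple of the number of envs")
    steps = frames_per_batch // n_envs
    return tuple(env_batch_size) + (steps,)


print(collected_batch_shape((5,), 10))  # (5, 2)
```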
forward(tensordict: TensorDictBase) → TensorDictBase

Reads the input tensordict and applies the transform to the selected keys.

transform_env_batch_size(batch_size: Size)

Transforms the batch size of the parent environment.
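Since batch_size and reshape_fn are mutually exclusive, one way to picture the new batch size is as either the explicit batch_size or the shape obtained by applying reshape_fn to data carrying the parent batch size. The sketch below models that idea on shapes only; `FakeData` and `new_env_batch_size` are hypothetical stand-ins, the behaviour when neither option is given is an assumption, and TorchRL's internal implementation may differ:

```python
from math import prod
from typing import Callable, Optional, Tuple

Shape = Tuple[int, ...]


class FakeData:
    """Hypothetical shape-only stand-in for a tensordict."""

    def __init__(self, shape: Shape):
        self.shape = tuple(shape)

    def reshape(self, *shape: int) -> "FakeData":
        if prod(shape) != prod(self.shape):
            raise ValueError(f"cannot reshape {self.shape} into {shape}")
        return FakeData(shape)


def new_env_batch_size(
    parent: Shape,
    batch_size: Optional[Shape] = None,
    reshape_fn: Optional[Callable[[FakeData], FakeData]] = None,
) -> Shape:
    # batch_size and reshape_fn are documented as mutually exclusive.
    if batch_size is not None and reshape_fn is not None:
        raise ValueError("batch_size and reshape_fn are mutually exclusive")
    if batch_size is not None:
        return tuple(batch_size)
    if reshape_fn is not None:
        # Apply the function to a placeholder carrying the parent batch size
        # and read off the resulting shape.
        return reshape_fn(FakeData(parent)).shape
    # Assumption for this sketch: with neither option, keep the parent size.
    return tuple(parent)


print(new_env_batch_size((), batch_size=(5,)))                            # (5,)
print(new_env_batch_size((), reshape_fn=lambda data: data.reshape(1, 1)))  # (1, 1)
```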

transform_input_spec(input_spec: Composite) → Composite

Transforms the input spec so that the resulting spec matches the transform mapping.

Parameters:

input_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
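In TorchRL, a spec's shape starts with the environment batch size followed by the feature dimensions: MyEnv in the examples declares an unbatched observation of shape [3], and its rollouts carry a leading [5] once the transform is applied. Transforming a spec can therefore be pictured as swapping the leading batch dimensions for the new ones. The pure-Python helper `rebatch_spec_shape` below is a hypothetical illustration of that shape bookkeeping only, not TorchRL's spec-transform code:

```python
from typing import Sequence, Tuple


def rebatch_spec_shape(
    spec_shape: Sequence[int], old_batch: Sequence[int], new_batch: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical helper: replace the leading batch dims of a spec shape."""
    n = len(old_batch)
    if tuple(spec_shape[:n]) != tuple(old_batch):
        raise ValueError("spec shape does not start with the old batch size")
    feature_shape = tuple(spec_shape[n:])  # trailing, per-element dimensions
    return tuple(new_batch) + feature_shape


# MyEnv's observation spec: unbatched, 3 features.
print(rebatch_spec_shape((3,), (), (5,)))         # (5, 3)
# Going from a [1, 1] batch back to an unbatched env.
print(rebatch_spec_shape((1, 1, 3), (1, 1), ()))  # (3,)
```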

transform_output_spec(output_spec: Composite) → Composite

Transforms the output spec so that the resulting spec matches the transform mapping.

This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec().

Parameters:

output_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
