捷徑

CatFrames

class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, padding='same', padding_value=0, as_inverse=False, reset_key: Optional[NestedKey] = None, done_key: Optional[NestedKey] = None)[source]

將連續的觀察影格串連成單一張量。

例如,這可以解釋觀察到的特徵的移動/速度。 在“使用深度強化學習玩 Atari”中提出( https://arxiv.org/pdf/1312.5602.pdf)。

在轉換後的環境中使用時, CatFrames 是一個有狀態的類別,可以透過呼叫 reset() 方法將其重置為其原始狀態。此方法接受帶有 "_reset" 條目的 tensordicts,該條目指示要重置的緩衝區。

參數:
  • N (int) – 要串接的觀測數量。

  • dim (int) – 串接觀測值的維度。應為負數,以確保其與不同 batch_size 的環境相容。

  • in_keys (sequence of NestedKey, optional) – 指向要串接的影格的鍵。預設為 [“pixels”]。

  • out_keys (sequence of NestedKey, optional) – 指向輸出要寫入位置的鍵。預設為 in_keys 的值。

  • padding (str, optional) – 填充方法。選項為 "same""constant"。預設為 "same",即使用第一個值進行填充。

  • padding_value (float, optional) – 如果 padding="constant",則用於填充的值。預設為 0。

  • as_inverse (bool, optional) – 如果為 True,則轉換將作為反向轉換應用。預設為 False

  • reset_key (NestedKey, optional) – 要用作部分重置指示器的重置鍵。必須是唯一的。如果未提供,則預設為父環境的唯一重置鍵(如果只有一個),否則會引發例外。

  • done_key (NestedKey, optional) – 要用作部分完成指示器的完成鍵。必須是唯一的。如果未提供,則預設為 "done"

範例

>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv('Pendulum-v1'),
...     Compose(
...         UnsqueezeTransform(-1, in_keys=["observation"]),
...         CatFrames(N=4, dim=-1, in_keys=["observation"]),
...     )
... )
>>> print(env.rollout(3))

CatFrames 轉換也可以離線使用,以不同規模重現線上影格串接的效果(或為了限制記憶體消耗)。以下範例提供了完整的圖片,以及 torchrl.data.ReplayBuffer 的使用。

範例

>>> from torchrl.envs.utils import RandomPolicy        >>> from torchrl.envs import UnsqueezeTransform, CatFrames
>>> from torchrl.collectors import SyncDataCollector
>>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension
>>> env = TransformedEnv(
...     GymEnv("CartPole-v1", from_pixels=True),
...     Compose(
...         ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
...         Resize(in_keys=["pixels_trsf"], w=64, h=64),
...         GrayScale(in_keys=["pixels_trsf"]),
...         UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
...         CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
...     )
... )
>>> # we design a collector
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=10,
...     total_frames=1000,
... )
>>> for data in collector:
...     print(data)
...     break
>>> # now let's create a transform for the replay buffer. We don't need to unsqueeze the data here.
>>> # however, we need to point to both the pixel entry at the root and at the next levels:
>>> t = Compose(
...         ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...         Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
...         GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...         CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
... )
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16)
>>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
>>> rb.add(data_exclude)
>>> s = rb.sample(1) # the buffer has only one element
>>> # let's check that our sample is the same as the batch collected during inference
>>> assert (data.exclude("collector")==s.squeeze(0).exclude("index", "collector")).all()

注意

CatFrames 目前僅支援根目錄中的 "done" 訊號。巢狀 done,例如在 MARL 設定中找到的那些,目前不支援。如果需要此功能,請在 TorchRL repo 上提出 issue。

forward(tensordict: TensorDictBase) TensorDictBase[source]

讀取輸入 tensordict,並針對選定的鍵,應用轉換。

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]

轉換觀察規格,使產生的規格與轉換映射匹配。

參數:

observation_spec (TensorSpec) – 轉換前的規格

回傳:

轉換後預期的規格

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發人員的深入教學課程

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源