捷徑

FlattenObservation

class torchrl.envs.transforms.FlattenObservation(first_dim: int, last_dim: int, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, allow_positive_dim: bool = False)[原始碼]

展平張量的相鄰維度。

參數:
  • first_dim (int) – 要展平的維度的第一個維度。

  • last_dim (int) – 要展平的維度的最後一個維度。

  • in_keys (NestedKey 序列, optional) – 要展平的條目。如果未提供,則假設為 ["pixels"]

  • out_keys (NestedKey 序列, optional) – 展平的觀測鍵。如果未提供,則假設為 in_keys

  • allow_positive_dim (bool, optional) – 若為 True,則接受正維度。FlattenObservation 會將這些維度映射到輸入張量的第 n 個特徵維度(即父環境批次大小後的第 n 個維度)。預設為 False,即不允許非負維度。

forward(tensordict: TensorDictBase) TensorDictBase

讀取輸入的 tensordict,並針對選定的鍵值套用轉換。

對於任何專屬於父環境的操作(例如 FrameSkip),請修改 _step 方法。只有在需要修改輸入的 tensordict 時,才應覆寫 _call()

_call() 將會被 TransformedEnv.step()TransformedEnv.reset() 呼叫。

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[原始碼]

轉換 observation spec,使產生的 spec 符合轉換映射。

參數:

observation_spec (TensorSpec) – 轉換前的 spec

回傳值:

轉換後預期的 spec

文件

取得 PyTorch 的完整開發者文件

查看文件

教學課程

取得適合初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得您問題的解答

查看資源