FlattenObservation¶
- class torchrl.envs.transforms.FlattenObservation(first_dim: int, last_dim: int, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, allow_positive_dim: bool = False)[原始碼]¶
展平張量的相鄰維度。
- 參數:
first_dim (int) – 要展平的維度的第一個維度。
last_dim (int) – 要展平的維度的最後一個維度。
in_keys (NestedKey 序列, optional) – 要展平的條目。如果未提供,則假設為
["pixels"]
。out_keys (NestedKey 序列, optional) – 展平的觀測鍵。如果未提供,則假設為
in_keys
。allow_positive_dim (bool, optional) – 若為
True
,則接受正維度。FlattenObservation
會將這些維度映射到輸入張量的第 n 個特徵維度(即父環境批次大小後的第 n 個維度)。預設為 False,即不允許非負維度。
- forward(tensordict: TensorDictBase) TensorDictBase ¶
讀取輸入的 tensordict,並針對選定的鍵值套用轉換。
對於任何專屬於父環境的操作(例如 FrameSkip),請修改 _step 方法。只有在需要修改輸入的 tensordict 時,才應覆寫
_call()
。_call()
將會被TransformedEnv.step()
和TransformedEnv.reset()
呼叫。
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [原始碼]¶
轉換 observation spec,使產生的 spec 符合轉換映射。
- 參數:
observation_spec (TensorSpec) – 轉換前的 spec
- 回傳值:
轉換後預期的 spec