CatTensors¶
- class torchrl.envs.transforms.CatTensors(in_keys: Optional[Sequence[NestedKey]] = None, out_key: NestedKey = 'observation_vector', dim: int = - 1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)[source]¶
將多個鍵連接成一個張量。
如果多個鍵描述單一狀態(例如“observation_position”和“observation_velocity”),這特別有用
- 參數:
in_keys (NestedKey 的序列) – 要連接的鍵。 如果 None (或未提供),則會在第一次使用轉換時從父環境中檢索鍵。 只有在設定了父項時,此行為才會起作用。
out_key (NestedKey) – 結果張量的鍵。
dim (int, 選用) – 將發生串連的維度。 預設值為
-1
。
- 關鍵字參數:
del_keys (bool, optional) – 如果
True
,輸入值將在串聯後被刪除。預設值為True
。unsqueeze_if_oor (bool, optional) – 如果
True
,CatTensor 將檢查要串聯的 tensors 是否存在指定的維度。如果不存在,tensors 將沿該維度進行 unsqueeze 操作。預設值為False
。sort (bool, optional) – 如果
True
,keys 將在轉換中進行排序。否則,將採用使用者提供的順序。預設為True
。
範例
>>> transform = CatTensors(in_keys=["key1", "key2"]) >>> td = TensorDict({"key1": torch.zeros(1, 1), ... "key2": torch.ones(1, 1)}, [1]) >>> _ = transform(td) >>> print(td.get("observation_vector")) tensor([[0., 1.]]) >>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True) >>> td = TensorDict({"key1": torch.zeros(1), ... "key2": torch.ones(1)}, []) >>> _ = transform(td) >>> print(td.get("observation_vector").shape) torch.Size([2, 1])
- forward(tensordict: TensorDictBase) TensorDictBase ¶
讀取輸入 tensordict,並針對選定的 keys 應用轉換。
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [原始碼]¶
轉換 observation spec,使產生的 spec 與轉換映射匹配。
- 參數:
observation_spec (TensorSpec) – 轉換前的 spec
- 回傳值:
轉換後預期的 spec