快捷鍵

CatTensors

class torchrl.envs.transforms.CatTensors(in_keys: Optional[Sequence[NestedKey]] = None, out_key: NestedKey = 'observation_vector', dim: int = - 1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)[source]

將多個鍵連接成一個張量。

如果多個鍵描述單一狀態(例如“observation_position”和“observation_velocity”),這特別有用

參數:
  • in_keys (NestedKey 的序列) – 要連接的鍵。 如果 None (或未提供),則會在第一次使用轉換時從父環境中檢索鍵。 只有在設定了父項時,此行為才會起作用。

  • out_key (NestedKey) – 結果張量的鍵。

  • dim (int, 選用) – 將發生串連的維度。 預設值為 -1

關鍵字參數:
  • del_keys (bool, optional) – 如果 True,輸入值將在串聯後被刪除。預設值為 True

  • unsqueeze_if_oor (bool, optional) – 如果 True,CatTensor 將檢查要串聯的 tensors 是否存在指定的維度。如果不存在,tensors 將沿該維度進行 unsqueeze 操作。預設值為 False

  • sort (bool, optional) – 如果 True,keys 將在轉換中進行排序。否則,將採用使用者提供的順序。預設為 True

範例

>>> transform = CatTensors(in_keys=["key1", "key2"])
>>> td = TensorDict({"key1": torch.zeros(1, 1),
...     "key2": torch.ones(1, 1)}, [1])
>>> _ = transform(td)
>>> print(td.get("observation_vector"))
tensor([[0., 1.]])
>>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True)
>>> td = TensorDict({"key1": torch.zeros(1),
...     "key2": torch.ones(1)}, [])
>>> _ = transform(td)
>>> print(td.get("observation_vector").shape)
torch.Size([2, 1])
forward(tensordict: TensorDictBase) TensorDictBase

讀取輸入 tensordict,並針對選定的 keys 應用轉換。

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[原始碼]

轉換 observation spec,使產生的 spec 與轉換映射匹配。

參數:

observation_spec (TensorSpec) – 轉換前的 spec

回傳值:

轉換後預期的 spec

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源