DTypeCastTransform¶
- class torchrl.envs.transforms.DTypeCastTransform(dtype_in: dtype, dtype_out: dtype, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, in_keys_inv: Optional[Sequence[NestedKey]] = None, out_keys_inv: Optional[Sequence[NestedKey]] = None)[source]¶
將選定鍵的 dtype 轉換為另一個 dtype。
取決於建構期間是否提供
in_keys
或in_keys_inv
,類別行為將會改變如果提供了 keys,則只有這些條目會從
dtype_in
轉換為dtype_out
條目;如果沒有提供 keys,且物件位於 transforms 的環境暫存器中,則 dtype 設定為
dtype_in
的輸入和輸出規格將分別用作 in_keys_inv / in_keys。如果沒有提供 keys,且物件在沒有環境的情況下使用,則
forward
/inverse
傳遞將掃描輸入 tensordict 中所有dtype_in
的值,並將它們映射到dtype_out
tensor。對於大型資料結構,這可能會影響效能,因為此掃描並非沒有代價。要轉換的 keys 將不會被快取。請注意,在這種情況下,不能傳遞 out_keys(或 out_keys_inv),因為處理 keys 的順序無法精確預期。
- 參數:
dtype_in (torch.dtype) – 輸入 dtype (來自環境)。
dtype_out (torch.dtype) – 輸出 dtype (用於模型訓練)。
in_keys (NestedKey 序列, optional) – 要在暴露給外部物件和函數之前轉換為
dtype_out
的dtype_in
keys 列表。out_keys (NestedKey 序列, optional) – 目的地 keys 列表。如果未提供,則預設為
in_keys
。in_keys_inv (NestedKey 序列, optional) – 要在傳遞給包含的 base_env 或儲存體之前轉換為
dtype_in
的dtype_out
keys 列表。out_keys_inv (NestedKey 序列, optional) – 反向轉換的目的地 keys 列表。如果未提供,則預設為
in_keys_inv
。
範例
>>> td = TensorDict( ... {'obs': torch.ones(1, dtype=torch.double), ... 'not_transformed': torch.ones(1, dtype=torch.double), ... }, []) >>> transform = DTypeCastTransform(torch.double, torch.float, in_keys=["obs"]) >>> _ = transform(td) >>> print(td.get("obs").dtype) torch.float32 >>> print(td.get("not_transformed").dtype) torch.float64
在“自動”模式下,所有 float64 條目都會被轉換
範例
>>> td = TensorDict( ... {'obs': torch.ones(1, dtype=torch.double), ... 'not_transformed': torch.ones(1, dtype=torch.double), ... }, []) >>> transform = DTypeCastTransform(torch.double, torch.float) >>> _ = transform(td) >>> print(td.get("obs").dtype) torch.float32 >>> print(td.get("not_transformed").dtype) torch.float32
當在未指定轉換 keys 的情況下建構環境時,相同的行為也是規則
範例
>>> class MyEnv(EnvBase): ... def __init__(self): ... super().__init__() ... self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64)) ... self.action_spec = Unbounded((), dtype=torch.float64) ... self.reward_spec = Unbounded((1,), dtype=torch.float64) ... self.done_spec = Unbounded((1,), dtype=torch.bool) ... def _reset(self, data=None): ... return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, []) ... def _step(self, data): ... assert data["action"].dtype == torch.float64 ... reward = self.reward_spec.rand() ... done = torch.zeros((1,), dtype=torch.bool) ... obs = self.observation_spec.rand() ... assert reward.dtype == torch.float64 ... assert obs["obs"].dtype == torch.float64 ... return obs.empty().set("next", obs.update({"reward": reward, "done": done})) ... def _set_seed(self, seed): ... pass >>> env = TransformedEnv(MyEnv(), DTypeCastTransform(torch.double, torch.float)) >>> assert env.action_spec.dtype == torch.float32 >>> assert env.observation_spec["obs"].dtype == torch.float32 >>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype >>> print(env.rollout(2)) TensorDict( fields={ action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=cpu, is_shared=False), obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=cpu, is_shared=False) >>> assert env.transform.in_keys == ["obs", "reward"] >>> assert env.transform.in_keys_inv == ["action"]
- forward(tensordict: TensorDictBase = None) TensorDictBase [source]¶
讀取輸入 tensordict,並對於選定的 keys,套用轉換。
- transform_input_spec(input_spec: TensorSpec) TensorSpec [source]¶
轉換輸入規格,使產生的規格符合轉換映射。
- 參數:
input_spec (TensorSpec) – 轉換前的規格
- 傳回:
轉換後預期的規格
- transform_observation_spec(observation_spec)[source]¶
轉換觀察規格,使產生的規格符合轉換映射。
- 參數:
observation_spec (TensorSpec) – 轉換前的規格
- 傳回:
轉換後預期的規格
- transform_output_spec(output_spec: Composite) Composite [source]¶
轉換輸出規格,使產生的規格符合轉換映射。
這個方法通常應該保持不變。 應該使用
transform_observation_spec()
、transform_reward_spec()
和transformfull_done_spec()
來實現變更。 :param output_spec: 轉換前的規格 :type output_spec: TensorSpec- 傳回:
轉換後預期的規格