捷徑

DoubleToFloat

class torchrl.envs.transforms.DoubleToFloat(in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, in_keys_inv: Optional[Sequence[NestedKey]] = None, out_keys_inv: Optional[Sequence[NestedKey]] = None)[來源]

針對選定的鍵,將一個 dtype 轉換為另一個。

根據在建構期間是否提供了 in_keysin_keys_inv,此類別的行為會有所改變

  • 如果提供了鍵,則只有這些條目會從 float64 轉換為 float32 條目;

  • 如果沒有提供鍵,且物件位於轉換的環境登錄檔中,則 dtype 設定為 float64 的輸入和輸出規格將分別用作 in_keys_inv / in_keys。

  • 如果沒有提供鍵值,且物件在沒有環境的情況下使用,forward / inverse 傳遞程序將會掃描輸入的 tensordict,尋找所有 float64 值,並將它們映射到 float32 張量。對於大型資料結構,這可能會影響效能,因為此掃描並非沒有代價。要轉換的鍵值不會被快取。請注意,在這種情況下,無法傳遞 out_keys(或 out_keys_inv),因為處理鍵值的順序無法精確預測。

參數:
  • in_keys (NestedKey 的序列, optional) – 要轉換為 float 的雙精度鍵值列表,然後再暴露給外部物件和函式。

  • out_keys (NestedKey 的序列, optional) – 目的地鍵值列表。如果未提供,則預設為 in_keys

  • in_keys_inv (NestedKey 的序列, optional) – 要轉換為雙精度浮點數的浮點數鍵值列表,然後再傳遞給包含的 base_env 或儲存體。

  • out_keys_inv (NestedKey 的序列, optional) – 逆轉換的目的地鍵值列表。如果未提供,則預設為 in_keys_inv

範例

>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     'not_transformed': torch.ones(1, dtype=torch.double),
... }, [])
>>> transform = DoubleToFloat(in_keys=["obs"])
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float64

在「自動」模式中,所有 float64 條目都會被轉換

範例

>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     'not_transformed': torch.ones(1, dtype=torch.double),
... }, [])
>>> transform = DoubleToFloat()
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float32

當環境在沒有指定轉換鍵值的情況下構建時,也適用相同的行為

範例

>>> class MyEnv(EnvBase):
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64))
...         self.action_spec = Unbounded((), dtype=torch.float64)
...         self.reward_spec = Unbounded((1,), dtype=torch.float64)
...         self.done_spec = Unbounded((1,), dtype=torch.bool)
...     def _reset(self, data=None):
...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
...     def _step(self, data):
...         assert data["action"].dtype == torch.float64
...         reward = self.reward_spec.rand()
...         done = torch.zeros((1,), dtype=torch.bool)
...         obs = self.observation_spec.rand()
...         assert reward.dtype == torch.float64
...         assert obs["obs"].dtype == torch.float64
...         return obs.empty().set("next", obs.update({"reward": reward, "done": done}))
...     def _set_seed(self, seed):
...         pass
>>> env = TransformedEnv(MyEnv(), DoubleToFloat())
>>> assert env.action_spec.dtype == torch.float32
>>> assert env.observation_spec["obs"].dtype == torch.float32
>>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> assert env.transform.in_keys == ["obs", "reward"]
>>> assert env.transform.in_keys_inv == ["action"]

文件

存取 PyTorch 的全面開發者文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您問題的解答

檢視資源