快捷方式

ObservationNorm

class torchrl.envs.transforms.ObservationNorm(loc: Optional[float, torch.Tensor] = None, scale: Optional[float, torch.Tensor] = None, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, standard_normal: bool = False, eps: float | None = None)[source]

觀測值的仿射轉換層。

根據以下公式正規化觀測值:

\[obs = obs * scale + loc\]
參數:
  • loc (數值張量) – 仿射轉換的位置

  • scale (數值張量) – 仿射轉換的比例

  • in_keys (NestedKey 的序列, 選用) – 要正規化的條目。預設為 [“observation”, “pixels”]。所有條目將使用相同的值進行正規化:如果需要不同的行為(例如,像素和狀態的不同正規化),則應使用不同的 ObservationNorm 物件。

  • out_keys (NestedKey 的序列, 選用) – 輸出條目。預設為 in_keys 的值。

  • in_keys_inv (NestedKey 的序列, 選用) – ObservationNorm 也支援反向轉換。只有在將金鑰列表提供給 in_keys_inv 時才會發生這種情況。如果未提供任何金鑰,則只會呼叫正向轉換。

  • out_keys_inv (NestedKey 的序列, 選用) – 反向轉換的輸出條目。預設為 in_keys_inv 的值。

  • standard_normal (bool, 選用) –

    如果 True,則轉換將為

    \[obs = (obs-loc)/scale\]

    因為它是為了標準化而進行的。預設為 False

  • eps (float, 選用) – 在 standard_normal 情況下,比例的 epsilon 增量。如果無法直接從比例 dtype 恢復,則預設為 1e-6

範例

>>> torch.set_default_tensor_type(torch.DoubleTensor)
>>> r = torch.randn(100, 3)*torch.randn(3) + torch.randn(3)
>>> td = TensorDict({'obs': r}, [100])
>>> transform = ObservationNorm(
...     loc = td.get('obs').mean(0),
...     scale = td.get('obs').std(0),
...     in_keys=["obs"],
...     standard_normal=True)
>>> _ = transform(td)
>>> print(torch.isclose(td.get('obs').mean(0),
...     torch.zeros(3)).all())
tensor(True)
>>> print(torch.isclose(td.get('next_obs').std(0),
...     torch.ones(3)).all())
tensor(True)

正規化統計資訊可以自動計算: .. rubric:: 範例

>>> from torchrl.envs.libs.gym import GymEnv
>>> torch.manual_seed(0)
>>> env = GymEnv("Pendulum-v1")
>>> env = TransformedEnv(env, ObservationNorm(in_keys=["observation"]))
>>> env.set_seed(0)
>>> env.transform.init_stats(100)
>>> print(env.transform.loc, env.transform.scale)
tensor([-1.3752e+01, -6.5087e-03,  2.9294e-03], dtype=torch.float32) tensor([14.9636,  2.5608,  0.6408], dtype=torch.float32)
init_stats(num_iter: int, reduce_dim: Union[int, Tuple[int]] = 0, cat_dim: Optional[int] = None, key: Optional[NestedKey] = None, keep_dims: Optional[Tuple[int]] = None) None[source]

初始化父環境的 loc 和 scale 統計資訊。

正規化常數應理想地使觀測統計資訊接近標準高斯分佈的統計資訊。此方法計算位置和比例張量,該張量將根據在高斯分佈上擬合的數據,以給定的步數,通過父環境隨機生成,來經驗性地計算高斯分佈的平均值和標準差。

參數:
  • num_iter (int) – 環境中要運行的隨機迭代次數。

  • reduce_dim (intint 的元組, 選用) – 計算平均值和標準差的維度。預設為 0。

  • cat_dim (int, optional) – 將收集的批次沿此維度串聯。它必須等於 reduce_dim(如果 reduce_dim 是整數)或成為 reduce_dim 元組的一部分。預設值與 reduce_dim 相同。

  • key (NestedKey, optional) – 如果提供,將從結果 tensordict 中的該鍵檢索摘要統計資訊。否則,將使用 ObservationNorm.in_keys 中的第一個鍵。

  • keep_dims (tuple of int, optional) – 在 loc 和 scale 中保留的維度。例如,當對最後兩個維度(而非第三個維度)進行 3D 張量正規化時,可能需要位置和縮放具有 [C, 1, 1] 的形狀。預設為 None。

transform_input_spec(input_spec)[source]

轉換輸入規格,使產生的規格與轉換映射匹配。

參數:

input_spec (TensorSpec) – 轉換前的規格

回傳值:

轉換後預期的規格

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]

轉換觀測規格,使產生的規格與轉換映射匹配。

參數:

observation_spec (TensorSpec) – 轉換前的規格

回傳值:

轉換後預期的規格

文件

獲取 PyTorch 的全面開發者文件

檢視文件

教學

獲取針對初學者和高級開發人員的深入教學

檢視教學

資源

查找開發資源並獲得您問題的解答

檢視資源