ObservationNorm¶
- class torchrl.envs.transforms.ObservationNorm(loc: Optional[float, torch.Tensor] = None, scale: Optional[float, torch.Tensor] = None, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, standard_normal: bool = False, eps: float | None = None)[source]¶
觀測值的仿射轉換層。
根據以下公式正規化觀測值:
\[obs = obs * scale + loc\]- 參數:
loc (數值 或 張量) – 仿射轉換的位置
scale (數值 或 張量) – 仿射轉換的比例
in_keys (NestedKey 的序列, 選用) – 要正規化的條目。預設為 [“observation”, “pixels”]。所有條目將使用相同的值進行正規化:如果需要不同的行為(例如,像素和狀態的不同正規化),則應使用不同的
ObservationNorm
物件。out_keys (NestedKey 的序列, 選用) – 輸出條目。預設為 in_keys 的值。
in_keys_inv (NestedKey 的序列, 選用) – ObservationNorm 也支援反向轉換。只有在將金鑰列表提供給
in_keys_inv
時才會發生這種情況。如果未提供任何金鑰,則只會呼叫正向轉換。out_keys_inv (NestedKey 的序列, 選用) – 反向轉換的輸出條目。預設為 in_keys_inv 的值。
standard_normal (bool, 選用) –
如果
True
,則轉換將為\[obs = (obs-loc)/scale\]因為它是為了標準化而進行的。預設為 False。
eps (float, 選用) – 在
standard_normal
情況下,比例的 epsilon 增量。如果無法直接從比例 dtype 恢復,則預設為1e-6
。
範例
>>> torch.set_default_tensor_type(torch.DoubleTensor) >>> r = torch.randn(100, 3)*torch.randn(3) + torch.randn(3) >>> td = TensorDict({'obs': r}, [100]) >>> transform = ObservationNorm( ... loc = td.get('obs').mean(0), ... scale = td.get('obs').std(0), ... in_keys=["obs"], ... standard_normal=True) >>> _ = transform(td) >>> print(torch.isclose(td.get('obs').mean(0), ... torch.zeros(3)).all()) tensor(True) >>> print(torch.isclose(td.get('next_obs').std(0), ... torch.ones(3)).all()) tensor(True)
正規化統計資訊可以自動計算: .. rubric:: 範例
>>> from torchrl.envs.libs.gym import GymEnv >>> torch.manual_seed(0) >>> env = GymEnv("Pendulum-v1") >>> env = TransformedEnv(env, ObservationNorm(in_keys=["observation"])) >>> env.set_seed(0) >>> env.transform.init_stats(100) >>> print(env.transform.loc, env.transform.scale) tensor([-1.3752e+01, -6.5087e-03, 2.9294e-03], dtype=torch.float32) tensor([14.9636, 2.5608, 0.6408], dtype=torch.float32)
- init_stats(num_iter: int, reduce_dim: Union[int, Tuple[int]] = 0, cat_dim: Optional[int] = None, key: Optional[NestedKey] = None, keep_dims: Optional[Tuple[int]] = None) None [source]¶
初始化父環境的 loc 和 scale 統計資訊。
正規化常數應理想地使觀測統計資訊接近標準高斯分佈的統計資訊。此方法計算位置和比例張量,該張量將根據在高斯分佈上擬合的數據,以給定的步數,通過父環境隨機生成,來經驗性地計算高斯分佈的平均值和標準差。
- 參數:
num_iter (int) – 環境中要運行的隨機迭代次數。
reduce_dim (int 或 int 的元組, 選用) – 計算平均值和標準差的維度。預設為 0。
cat_dim (int, optional) – 將收集的批次沿此維度串聯。它必須等於 reduce_dim(如果 reduce_dim 是整數)或成為 reduce_dim 元組的一部分。預設值與 reduce_dim 相同。
key (NestedKey, optional) – 如果提供,將從結果 tensordict 中的該鍵檢索摘要統計資訊。否則,將使用
ObservationNorm.in_keys
中的第一個鍵。keep_dims (tuple of int, optional) – 在 loc 和 scale 中保留的維度。例如,當對最後兩個維度(而非第三個維度)進行 3D 張量正規化時,可能需要位置和縮放具有 [C, 1, 1] 的形狀。預設為 None。
- transform_input_spec(input_spec)[source]¶
轉換輸入規格,使產生的規格與轉換映射匹配。
- 參數:
input_spec (TensorSpec) – 轉換前的規格
- 回傳值:
轉換後預期的規格
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [source]¶
轉換觀測規格,使產生的規格與轉換映射匹配。
- 參數:
observation_spec (TensorSpec) – 轉換前的規格
- 回傳值:
轉換後預期的規格