SignTransform¶
- class torchrl.envs.transforms.SignTransform(in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None)[source]¶
一個計算 TensorDict 值符號的轉換。
此轉換讀取
in_keys
和in_keys_inv
中的 tensors,計算其元素的符號,並將產生的符號 tensors 寫入out_keys
和out_keys_inv
。- 參數:
in_keys (NestedKeys 的 list) – 輸入條目 (讀取)
out_keys (NestedKeys 的 list) – 輸入條目 (寫入)
in_keys_inv (NestedKeys 的 list) – 在
inv()
呼叫期間輸入的條目 (讀取)。out_keys_inv (NestedKeys 的 list) – 在
inv()
呼叫期間輸入的條目 (寫入)。
範例
>>> from torchrl.envs import GymEnv, TransformedEnv, SignTransform >>> base_env = GymEnv("Pendulum-v1") >>> env = TransformedEnv(base_env, SignTransform(in_keys=['observation'])) >>> r = env.rollout(100) >>> obs = r["observation"] >>> assert (torch.logical_or(torch.logical_or(obs == -1, obs == 1), obs == 0.0)).all()
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [source]¶
轉換觀察規格,使結果規格與轉換對應匹配。
- 參數:
observation_spec (TensorSpec) – 轉換前的規格
- 回傳:
轉換後的預期規格
- transform_reward_spec(reward_spec: TensorSpec) TensorSpec [原始碼]¶
轉換獎勵規格 (reward spec),使產生的規格符合轉換映射 (transform mapping)。
- 參數:
reward_spec (TensorSpec) – 轉換前的規格
- 回傳:
轉換後的預期規格