快捷方式

VecNorm

class torchrl.envs.transforms.VecNorm(in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, shared_td: Optional[TensorDictBase] = None, lock: Optional[Lock] = None, decay: float = 0.9999, eps: float = 0.0001, shapes: Optional[List[Size]] = None)[來源]

torchrl 環境的移動平均歸一化層。

VecNorm 會追蹤資料集的摘要統計資訊,以便即時進行標準化。如果轉換處於「eval」模式,則不會更新執行中的統計資訊。

如果有多個程序執行類似的環境,則可以傳遞放置在共享記憶體中的 TensorDictBase 實例:如果是這樣,每次查詢正規化層時,它都會更新共享相同參考的所有程序的值。

若要在推論時使用 VecNorm,並避免使用新的觀察值更新值,則應將此層替換為 to_observation_norm()。這將提供 VecNorm 的靜態版本,當來源轉換更新時,該版本不會更新。若要取得 VecNorm 層的凍結副本,請參閱 frozen_copy()

參數:
  • in_keys (NestedKey 的序列, 選用) – 要更新的鍵。預設值:["observation", "reward"]

  • out_keys (NestedKey 的序列, 選用) – 目的地鍵。預設為 in_keys

  • shared_td (TensorDictBase, 選用) – 包含轉換鍵的共享 tensordict。

  • lock (mp.Lock) – 用於防止程序之間發生競爭狀況的鎖。預設值為 None(在初始化期間建立鎖)。

  • decay (數字, 選用) – 移動平均的衰減率。預設值:0.99

  • eps (數字, 選用) – 執行中標準差的下限(用於數值下溢)。預設值為 1e-4。

  • shapes (List[torch.Size], 選用) – 如果提供,則表示每個 in_keys 的形狀。其長度必須與 in_keys 的長度相符。每個形狀必須與相應條目的尾隨維度相符。否則,條目的特徵維度(即所有不屬於 tensordict 批次大小的維度)將被視為特徵維度。

範例

>>> from torchrl.envs.libs.gym import GymEnv
>>> t = VecNorm(decay=0.9)
>>> env = GymEnv("Pendulum-v0")
>>> env = TransformedEnv(env, t)
>>> tds = []
>>> for _ in range(1000):
...     td = env.rand_step()
...     if td.get("done"):
...         _ = env.reset()
...     tds += [td]
>>> tds = torch.stack(tds, 0)
>>> print((abs(tds.get(("next", "observation")).mean(0))<0.2).all())
tensor(True)
>>> print((abs(tds.get(("next", "observation")).std(0)-1)<0.2).all())
tensor(True)
static build_td_for_shared_vecnorm(env: EnvBase, keys: Optional[Sequence[str]] = None, memmap: bool = False) TensorDictBase[source]

建立一個用於跨程序標準化的共享 tensordict。

參數:
  • env (EnvBase) – 用於建立 tensordict 的範例環境

  • keys (NestedKey 的序列, 選用) – 必須標準化的鍵。預設值為 ["next", "reward"]

  • memmap (bool) – 如果 True,則產生的 tensordict 將被轉換為記憶體映射(使用 memmap_())。否則,tensordict 將放置在共享記憶體中。

回傳值:

共享記憶體中的記憶體,將傳送到每個程序。

範例

>>> from torch import multiprocessing as mp
>>> queue = mp.Queue()
>>> env = make_env()
>>> td_shared = VecNorm.build_td_for_shared_vecnorm(env,
...     ["next", "reward"])
>>> assert td_shared.is_shared()
>>> queue.put(td_shared)
>>> # on workers
>>> v = VecNorm(shared_td=queue.get())
>>> env = TransformedEnv(make_env(), v)
forward(tensordict: TensorDictBase) TensorDictBase

讀取輸入 tensordict,並針對選定的鍵,套用轉換。

freeze() VecNorm[source]

凍結 VecNorm,避免在呼叫時更新統計資訊。

請參閱 unfreeze()

frozen_copy()[source]

傳回 Transform 的副本,該副本會追蹤統計資訊,但不會更新它們。

get_extra_state() OrderedDict[source]

傳回要包含在模組 state_dict 中的任何額外狀態。

如果您需要儲存額外的狀態,請為您的模組實作此函式和對應的 set_extra_state()。在建立模組的 state_dict() 時會呼叫此函式。

請注意,額外的狀態應該是可 pickle 的,以確保 state_dict 的工作序列化。我們僅提供序列化 Tensor 的向後相容性保證;如果其他物件的序列化 pickle 格式發生變更,則可能會破壞向後相容性。

回傳值:

要儲存在模組 state_dict 中的任何額外狀態

回傳類型:

物件

property loc

傳回一個 TensorDict,其中包含用於仿射轉換的 loc。

property scale

傳回一個 TensorDict,其中包含用於仿射轉換的 scale。

set_extra_state(state: OrderedDict) None[source]

設定載入的 state_dict 中包含的額外狀態。

此函數從 load_state_dict() 呼叫,以處理在 state_dict 中找到的任何額外狀態。如果您需要在其 state_dict 中儲存額外狀態,請為您的模組實作此函數和對應的 get_extra_state()

參數:

state (dict) – 來自 state_dict 的額外狀態

property standard_normal

locscale 給定的仿射轉換是否遵循標準常態方程式。

ObservationNorm 的 standard_normal 屬性類似。

總是傳回 True

to_observation_norm() Union[Compose, ObservationNorm][source]

將 VecNorm 轉換為可在推論時使用的 ObservationNorm 類別。

可以使用 state_dict() API 更新 ObservationNorm 層。

範例

>>> from torchrl.envs import GymEnv, VecNorm
>>> vecnorm = VecNorm(in_keys=["observation"])
>>> train_env = GymEnv("CartPole-v1", device=None).append_transform(
...     vecnorm)
>>>
>>> r = train_env.rollout(4)
>>>
>>> eval_env = GymEnv("CartPole-v1").append_transform(
...     vecnorm.to_observation_norm())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
>>>
>>> r = train_env.rollout(4)
>>> # Update entries with state_dict
>>> eval_env.transform.load_state_dict(
...     vecnorm.to_observation_norm().state_dict())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]

轉換 observation spec,使產生的 spec 符合轉換映射。

參數:

observation_spec (TensorSpec) – 轉換前的 spec

回傳值:

轉換後預期的 spec

unfreeze() VecNorm[source]

取消凍結 VecNorm。

請參閱 freeze()

文件

取得 PyTorch 的全面開發人員文件

檢視文件

教學

取得適合初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源