快捷方式

EnvCreator

class torchrl.envs.EnvCreator(create_env_fn: Callable[[...], EnvBase], create_env_kwargs: Optional[Dict] = None, share_memory: bool = True)[source]

環境建立器類別。

EnvCreator 是一個通用的環境建立器類別,可以在多重處理環境中建立環境時替換 lambda 函式。如果子程序上建立的環境必須與主程序共享資訊(例如,對於 VecNorm 轉換),EnvCreator 將把共享記憶體中 tensordict 的指標傳遞給每個程序,以便所有程序都同步。

參數:
  • create_env_fn (callable) – 一個返回 EnvBase 實例的可呼叫物件。

  • create_env_kwargs (dict, optional) – 環境建立器的 kwargs。

  • share_memory (bool, optional) – 如果為 False,則環境產生的 tensordict 將不會放置在共享記憶體中。

範例

>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations match on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p1)
...     for p in ps:
...         p.join()
env 1:  tensor([11.9934])
env 2:  tensor([11.9934])

文件

取得 PyTorch 的全面開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源