EnvCreator
- class torchrl.envs.EnvCreator(create_env_fn: Callable[..., EnvBase], create_env_kwargs: Optional[Dict] = None, share_memory: bool = True)
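Environment creator class.
EnvCreator is a generic environment creator class that can replace lambda functions when environments are created in a multiprocessing setting. If the environments created in subprocesses must share information with the main process (for example, the running statistics of the VecNorm transform), EnvCreator passes each process pointers to the tensordicts held in shared memory, so that all processes stay synchronized.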
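The synchronization idea can be illustrated without torchrl. The following is a minimal, single-process sketch under stated assumptions: `SharedStatsCreator` and `CountingEnv` are hypothetical names, and a plain dict stands in for the shared-memory tensordict that EnvCreator hands to each process. Every environment built by the same creator writes into one buffer, so an instance that never stepped still reads the counts produced by another.

```python
from typing import Callable, Dict


class CountingEnv:
    """Hypothetical environment that records how many observations it produced."""

    def __init__(self, stats: Dict[str, int]):
        # The stats buffer is owned by the creator, not by this env.
        self.stats = stats

    def step(self) -> None:
        # Each step updates the shared counter in place.
        self.stats["observation_count"] += 1


class SharedStatsCreator:
    """Hypothetical stand-in for EnvCreator: builds envs around one shared buffer."""

    def __init__(self, create_env_fn: Callable[[Dict[str, int]], CountingEnv]):
        self.create_env_fn = create_env_fn
        # Created once; in torchrl this role is played by a tensordict in shared memory.
        self.shared_stats: Dict[str, int] = {"observation_count": 0}

    def __call__(self) -> CountingEnv:
        # Every env receives a reference to the same buffer.
        return self.create_env_fn(self.shared_stats)


creator = SharedStatsCreator(CountingEnv)
env_1 = creator()
env_2 = creator()
for _ in range(10):
    env_1.step()

# env_2 never stepped, yet it sees the count produced by env_1.
print(env_1.stats["observation_count"], env_2.stats["observation_count"])  # 10 10
```

Across real processes a plain dict would be copied into each child rather than shared, which is why EnvCreator places the tensordict in shared memory instead.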
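- Parameters:
create_env_fn (callable) – a callable that returns an EnvBase instance.
create_env_kwargs (dict, optional) – the keyword arguments passed to create_env_fn when the environment is created.
share_memory (bool, optional) – if False, the tensordict produced by the environment will not be placed in shared memory.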
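How a stored keyword-argument dict reaches the factory can be sketched as follows. This is a hypothetical illustration, not torchrl's implementation: `make_env` and `KwargsEnvCreator` are invented names, and `make_env` returns a dict in place of a real EnvBase so the sketch runs without torchrl.

```python
from typing import Any, Callable, Dict, Optional


def make_env(env_name: str, max_steps: int = 200) -> Dict[str, Any]:
    # Hypothetical env factory; returns a dict in place of a real EnvBase.
    return {"env_name": env_name, "max_steps": max_steps}


class KwargsEnvCreator:
    """Hypothetical sketch of storing create_env_kwargs and forwarding them."""

    def __init__(
        self,
        create_env_fn: Callable[..., Any],
        create_env_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.create_env_fn = create_env_fn
        # A None default (copied into a fresh dict) avoids sharing one mutable {} across instances.
        self.create_env_kwargs = dict(create_env_kwargs or {})

    def __call__(self) -> Any:
        # The stored kwargs are unpacked into the factory on every call.
        return self.create_env_fn(**self.create_env_kwargs)


creator = KwargsEnvCreator(
    make_env, create_env_kwargs={"env_name": "Pendulum-v1", "max_steps": 50}
)
env = creator()
print(env)  # {'env_name': 'Pendulum-v1', 'max_steps': 50}
```

Keys that are not supplied fall back to the factory's own defaults, so `KwargsEnvCreator(make_env, {"env_name": "CartPole-v1"})()` keeps `max_steps=200`.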
Examples
>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations matches on
>>> # both workers, even though one of them has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p2)  # append p2 (not p1 again) so both workers are joined
...     for p in ps:
...         p.join()
env 1:  tensor([11.9934])
env 2:  tensor([11.9934])