Flat2TED
- class torchrl.data.Flat2TED(done_key='done', shift_key='shift', is_full_key='is_full', done_keys=('done', 'truncated', 'terminated'), reward_keys=('reward',))
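A storage loading hook that deserializes flattened TED data back into TED format.
- Parameters:
done_key (NestedKey, optional) – the key from which the done states are read. Defaults to ("next", "done").
shift_key (NestedKey, optional) – the key where the shift will be written. Defaults to "shift".
is_full_key (NestedKey, optional) – the key where the is_full attribute will be written. Defaults to "is_full".
done_keys (Tuple[NestedKey], optional) – a tuple of nested keys indicating the done entries. Defaults to ("done", "truncated", "terminated").
reward_keys (Tuple[NestedKey], optional) – a tuple of nested keys indicating the reward entries. Defaults to ("reward",).
Examples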
>>> import tempfile
>>>
>>> from tensordict import TensorDict
>>>
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage, Flat2TED
>>> from torchrl.envs import GymEnv
>>> import torch
>>>
>>> env = GymEnv("CartPole-v1")
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>> collector = SyncDataCollector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
>>> rb.register_save_hook(TED2Flat())
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     for i, data in enumerate(collector):
...         rb.extend(data)
...         rb.dumps(tmpdir)
...     # load the data to represent it
...     td = TensorDict.load(tmpdir + "/storage/")
...
...     rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
...     rb_load.register_load_hook(Flat2TED())
...     rb_load.load(tmpdir)
...     print("storage after loading", rb_load[:])
...     assert (rb[:] == rb_load[:]).all()
storage after loading TensorDict(
    fields={
        action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
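To build intuition for what "flattened" versus TED layout means, the following plain-Python sketch may help. It is a conceptual illustration only, not torchrl's implementation: it assumes a simplified flat layout in which each trajectory of n steps stores its n + 1 observations once, and it rebuilds the root and "next" observation sequences that TED keeps side by side. The helper names `unflatten_obs` and `flatten_obs` and the `traj_lengths` argument are hypothetical and do not exist in torchrl.

```python
from typing import List, Sequence, Tuple


def unflatten_obs(flat_obs: Sequence, traj_lengths: Sequence[int]) -> Tuple[List, List]:
    """Rebuild (root, next) observation lists from a flat sequence.

    Assumption (hypothetical layout): a trajectory with n steps contributes
    n + 1 consecutive entries to ``flat_obs`` (its n root observations plus
    the final "next" observation).
    """
    root, nxt = [], []
    start = 0
    for n in traj_lengths:
        if n <= 0:
            raise ValueError("each trajectory must contain at least one step")
        segment = list(flat_obs[start:start + n + 1])
        if len(segment) != n + 1:
            raise ValueError("flat_obs is too short for the given trajectory lengths")
        root.extend(segment[:-1])  # observation at step t
        nxt.extend(segment[1:])    # observation at step t + 1
        start += n + 1
    if start != len(flat_obs):
        raise ValueError("flat_obs has leftover entries")
    return root, nxt


def flatten_obs(root: Sequence, nxt: Sequence, traj_lengths: Sequence[int]) -> List:
    """Inverse of ``unflatten_obs``: store every observation only once."""
    flat, start = [], 0
    for n in traj_lengths:
        if n <= 0:
            raise ValueError("each trajectory must contain at least one step")
        flat.extend(root[start:start + n])  # all root observations
        flat.append(nxt[start + n - 1])     # last "next" observation of the trajectory
        start += n
    return flat


# Two trajectories: one with 3 steps, one with 2 steps.
flat = ["a0", "a1", "a2", "a3", "b0", "b1", "b2"]
root, nxt = unflatten_obs(flat, [3, 2])
print(root)  # ['a0', 'a1', 'a2', 'b0', 'b1']
print(nxt)   # ['a1', 'a2', 'a3', 'b1', 'b2']
```

In this toy layout the flat form holds 7 entries where the TED form holds 10, which suggests why a save hook such as TED2Flat is paired with a load hook such as Flat2TED when a buffer is written to and read from disk.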