SliceSamplerWithoutReplacement¶
- class torchrl.data.replay_buffers.SliceSamplerWithoutReplacement(*, num_slices: Optional[int] = None, slice_len: Optional[int] = None, drop_last: bool = False, end_key: Optional[NestedKey] = None, traj_key: Optional[NestedKey] = None, ends: Optional[Tensor] = None, trajectories: Optional[Tensor] = None, truncated_key: tensordict._nestedkey.NestedKey | None = ('next', 'truncated'), strict_length: bool = True, shuffle: bool = True, compile: bool | dict = False)[source]¶
沿著第一個維度取樣資料片段,給定開始和停止訊號,不放回。
這個類別用於靜態回放緩衝區或兩個回放緩衝區擴展之間。擴展回放緩衝區將重置取樣器,目前不允許連續不放回取樣。
- 關鍵字引數:
drop_last (bool, optional) – 如果
True
,則將捨棄最後一個不完整的樣本(如果有的話)。 如果False
,則將保留最後一個樣本。 預設值為False
。num_slices (int) – 要取樣的片段數量。批次大小必須大於或等於
num_slices
引數。與slice_len
互斥。slice_len (int) – 要取樣的片段長度。批次大小必須大於或等於
slice_len
引數,並且可以被它整除。與num_slices
互斥。end_key (NestedKey, optional) – 指示軌跡(或 episode)結束的鍵。預設值為
("next", "done")
。traj_key (NestedKey, optional) – 指示軌跡的鍵。預設值為
"episode"
(通常在 TorchRL 的資料集中使用)。ends (torch.Tensor, optional) – 包含執行結束訊號的 1d 布林張量。每當獲取
end_key
或traj_key
的成本很高,或者此訊號已準備好可用時使用。必須與cache_values=True
一起使用,並且不能與end_key
或traj_key
結合使用。trajectories (torch.Tensor, optional) – 包含執行 ID 的 1d 整數張量。每當獲取
end_key
或traj_key
的成本很高,或者此訊號已準備好可用時使用。必須與cache_values=True
一起使用,並且不能與end_key
或traj_key
結合使用。truncated_key (NestedKey, optional) – 如果不是
None
,則此引數指示截斷的訊號應寫入輸出資料的位置。 這用於向數值估算器指示提供的軌跡中斷的位置。 預設為("next", "truncated")
。 此功能僅適用於TensorDictReplayBuffer
實例(否則截斷的鍵會返回到sample()
方法返回的 info 字典中)。strict_length (bool, optional) – 如果
False
,則允許長度短於 slice_len(或 batch_size // num_slices)的軌跡出現在批次中。 如果True
,則會篩選掉短於要求的軌跡。 請注意,這可能會導致有效的 batch_size 短於要求的! 軌跡可以使用split_trajectories()
拆分。 預設為True
。shuffle (bool, optional) – 如果
False
,軌跡的順序不會被打亂。 預設為True
。compile (bool or dict of kwargs, optional) – 如果
True
,sample()
方法的瓶頸將使用compile()
進行編譯。 關鍵字引數也可以透過此引數傳遞給 torch.compile。 預設為False
。
注意
為了恢復儲存中的軌跡分割,
SliceSamplerWithoutReplacement
將首先嘗試在儲存中尋找traj_key
條目。 如果找不到,則會使用end_key
來重建 episodes。範例
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer >>> from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement >>> >>> rb = TensorDictReplayBuffer( ... storage=LazyMemmapStorage(1000), ... # asking for 10 slices for a total of 320 elements, ie, 10 trajectories of 32 transitions each ... sampler=SliceSamplerWithoutReplacement(num_slices=10), ... batch_size=320, ... ) >>> episode = torch.zeros(1000, dtype=torch.int) >>> episode[:300] = 1 >>> episode[300:550] = 2 >>> episode[550:700] = 3 >>> episode[700:] = 4 >>> data = TensorDict( ... { ... "episode": episode, ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5), ... "act": torch.randn((20,)).expand(1000, 20), ... "other": torch.randn((20, 50)).expand(1000, 20, 50), ... }, [1000] ... ) >>> rb.extend(data) >>> sample = rb.sample() >>> # since we want trajectories of 32 transitions but there are only 4 episodes to >>> # sample from, we only get 4 x 32 = 128 transitions in this batch >>> print("sample:", sample) >>> print("trajectories in sample", sample.get("episode").unique())
SliceSamplerWithoutReplacement
預設與 TorchRL 的大多數資料集相容,並允許使用者以類似 dataloader 的方式使用資料集範例
>>> import torch >>> >>> from torchrl.data.datasets import RobosetExperienceReplay >>> from torchrl.data import SliceSamplerWithoutReplacement >>> >>> torch.manual_seed(0) >>> num_slices = 10 >>> dataid = list(RobosetExperienceReplay.available_datasets)[0] >>> data = RobosetExperienceReplay(dataid, batch_size=320, ... sampler=SliceSamplerWithoutReplacement(num_slices=num_slices)) >>> # the last sample is kept, since drop_last=False by default >>> for i, batch in enumerate(data): ... print(batch.get("episode").unique()) tensor([ 5, 6, 8, 11, 12, 14, 16, 17, 19, 24]) tensor([ 1, 2, 7, 9, 10, 13, 15, 18, 21, 22]) tensor([ 0, 3, 4, 20, 23])