SliceSampler¶
- class torchrl.data.replay_buffers.SliceSampler(*, num_slices: Optional[int] = None, slice_len: Optional[int] = None, end_key: Optional[NestedKey] = None, traj_key: Optional[NestedKey] = None, ends: Optional[Tensor] = None, trajectories: Optional[Tensor] = None, cache_values: bool = False, truncated_key: tensordict._nestedkey.NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: Union[bool, int, Tuple[bool | int, bool | int]] = False)[source]¶
沿著第一維度採樣資料切片,給定起始和停止訊號。
這個類別會透過替換方式採樣子軌跡。對於沒有替換的版本,請參閱
SliceSamplerWithoutReplacement
。- 關鍵字引數:
num_slices (int) – 要採樣的切片數量。批次大小必須大於或等於
num_slices
引數。與slice_len
互斥。slice_len (int) – 要採樣的切片長度。批次大小必須大於或等於
slice_len
引數,並且可以被它整除。與num_slices
互斥。end_key (NestedKey, optional) – 指示軌跡(或 episode)結束的鍵。預設為
("next", "done")
。traj_key (NestedKey, optional) – 指示軌跡的鍵。預設為
"episode"
(通常用於 TorchRL 中的資料集)。ends (torch.Tensor, optional) – 包含執行結束訊號的一維布林張量。每當獲取
end_key
或traj_key
的成本很高,或此訊號已準備好可用時使用。必須與cache_values=True
一起使用,並且不能與end_key
或traj_key
結合使用。如果提供,則假定儲存已達到容量,並且如果ends
張量的最後一個元素為False
,則相同的軌跡跨越結束和開始。trajectories (torch.Tensor, optional) – 包含執行 ID 的一維整數張量。每當獲取
end_key
或traj_key
的成本很高,或此訊號已準備好可用時使用。必須與cache_values=True
一起使用,並且不能與end_key
或traj_key
結合使用。如果提供,則假定儲存已達到容量,並且如果軌跡張量的最後一個元素與第一個元素相同,則相同的軌跡跨越結束和開始。cache_values (bool, optional) –
用於靜態資料集。將會快取軌跡的起始和結束訊號。即使在呼叫
extend
期間軌跡索引發生變化,也可以安全使用,因為此操作會清除快取。警告
如果取樣器與被另一個緩衝區擴展的儲存器一起使用,則
cache_values=True
將無法工作。 例如:>>> buffer0 = ReplayBuffer(storage=storage, ... sampler=SliceSampler(num_slices=8, cache_values=True), ... writer=ImmutableWriter()) >>> buffer1 = ReplayBuffer(storage=storage, ... sampler=other_sampler) >>> # Wrong! Does not erase the buffer from the sampler of buffer0 >>> buffer1.extend(data)
警告
如果緩衝區在進程之間共享,並且一個進程負責寫入,另一個進程負責取樣,則
cache_values=True
將無法按預期工作,因為清除快取只能在本地完成。truncated_key (NestedKey, optional) – 如果不是
None
,則此參數指示截斷訊號應寫入輸出資料的位置。 這用於向值估算器指示提供的軌跡在哪裡中斷。預設值為("next", "truncated")
。此功能僅適用於TensorDictReplayBuffer
實例(否則,截斷的鍵將在sample()
方法傳回的資訊字典中傳回)。strict_length (bool, optional) – 如果
False
,則允許長度短於 slice_len (或 batch_size // num_slices) 的軌跡出現在批次中。如果True
,則會濾除短於要求的軌跡。請注意,這可能會導致有效的 batch_size 短於要求的批次大小!可以使用split_trajectories()
分割軌跡。預設值為True
。compile (bool or dict of kwargs, optional) – 如果
True
,sample()
方法的瓶頸將使用compile()
編譯。也可以使用此參數將關鍵字引數傳遞給 torch.compile。預設值為False
。span (bool, int, Tuple[bool | int, bool | int], optional) – 如果提供,取樣的軌跡將跨越左側和/或右側。這意味著提供的元素可能比要求的少。布林值表示每個軌跡至少取樣一個元素。整數 i 表示為每個取樣的軌跡至少收集 slice_len - i 個樣本。使用元組可以精細地控制左側(儲存軌跡的開頭)和右側(儲存軌跡的結尾)的跨度。
注意
為了恢復儲存中的軌跡分割,
SliceSampler
將首先嘗試在儲存中找到traj_key
條目。如果找不到,將使用end_key
來重建 episodes。範例
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer >>> from torchrl.data.replay_buffers.samplers import SliceSampler >>> torch.manual_seed(0) >>> rb = TensorDictReplayBuffer( ... storage=LazyMemmapStorage(1_000_000), ... sampler=SliceSampler(cache_values=True, num_slices=10), ... batch_size=320, ... ) >>> episode = torch.zeros(1000, dtype=torch.int) >>> episode[:300] = 1 >>> episode[300:550] = 2 >>> episode[550:700] = 3 >>> episode[700:] = 4 >>> data = TensorDict( ... { ... "episode": episode, ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5), ... "act": torch.randn((20,)).expand(1000, 20), ... "other": torch.randn((20, 50)).expand(1000, 20, 50), ... }, [1000] ... ) >>> rb.extend(data) >>> sample = rb.sample() >>> print("sample:", sample) >>> print("episodes", sample.get("episode").unique()) episodes tensor([1, 2, 3, 4], dtype=torch.int32)
SliceSampler
與 TorchRL 的大多數資料集預設相容範例
>>> import torch >>> >>> from torchrl.data.datasets import RobosetExperienceReplay >>> from torchrl.data import SliceSampler >>> >>> torch.manual_seed(0) >>> num_slices = 10 >>> dataid = list(RobosetExperienceReplay.available_datasets)[0] >>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices)) >>> for batch in data: ... batch = batch.reshape(num_slices, -1) ... break >>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1)) check that each batch only has one episode: tensor([[19], [14], [ 8], [10], [13], [ 4], [ 2], [ 3], [22], [ 8]])