捷徑

SliceSamplerWithoutReplacement

class torchrl.data.replay_buffers.SliceSamplerWithoutReplacement(*, num_slices: Optional[int] = None, slice_len: Optional[int] = None, drop_last: bool = False, end_key: Optional[NestedKey] = None, traj_key: Optional[NestedKey] = None, ends: Optional[Tensor] = None, trajectories: Optional[Tensor] = None, truncated_key: tensordict._nestedkey.NestedKey | None = ('next', 'truncated'), strict_length: bool = True, shuffle: bool = True, compile: bool | dict = False)[source]

沿著第一個維度取樣資料片段,給定開始和停止訊號,不放回。

這個類別用於靜態回放緩衝區或兩個回放緩衝區擴展之間。擴展回放緩衝區將重置取樣器,目前不允許連續不放回取樣。

關鍵字引數:
  • drop_last (bool, optional) – 如果 True,則將捨棄最後一個不完整的樣本(如果有的話)。 如果 False,則將保留最後一個樣本。 預設值為 False

  • num_slices (int) – 要取樣的片段數量。批次大小必須大於或等於 num_slices 引數。與 slice_len 互斥。

  • slice_len (int) – 要取樣的片段長度。批次大小必須大於或等於 slice_len 引數,並且可以被它整除。與 num_slices 互斥。

  • end_key (NestedKey, optional) – 指示軌跡(或 episode)結束的鍵。預設值為 ("next", "done")

  • traj_key (NestedKey, optional) – 指示軌跡的鍵。預設值為 "episode"(通常在 TorchRL 的資料集中使用)。

  • ends (torch.Tensor, optional) – 包含執行結束訊號的 1d 布林張量。每當獲取 end_keytraj_key 的成本很高,或者此訊號已準備好可用時使用。必須與 cache_values=True 一起使用,並且不能與 end_keytraj_key 結合使用。

  • trajectories (torch.Tensor, optional) – 包含執行 ID 的 1d 整數張量。每當獲取 end_keytraj_key 的成本很高,或者此訊號已準備好可用時使用。必須與 cache_values=True 一起使用,並且不能與 end_keytraj_key 結合使用。

  • truncated_key (NestedKey, optional) – 如果不是 None,則此引數指示截斷的訊號應寫入輸出資料的位置。 這用於向數值估算器指示提供的軌跡中斷的位置。 預設為 ("next", "truncated")。 此功能僅適用於 TensorDictReplayBuffer 實例(否則截斷的鍵會返回到 sample() 方法返回的 info 字典中)。

  • strict_length (bool, optional) – 如果 False,則允許長度短於 slice_len(或 batch_size // num_slices)的軌跡出現在批次中。 如果 True,則會篩選掉短於要求的軌跡。 請注意,這可能會導致有效的 batch_size 短於要求的! 軌跡可以使用 split_trajectories() 拆分。 預設為 True

  • shuffle (bool, optional) – 如果 False,軌跡的順序不會被打亂。 預設為 True

  • compile (bool or dict of kwargs, optional) – 如果 Truesample() 方法的瓶頸將使用 compile() 進行編譯。 關鍵字引數也可以透過此引數傳遞給 torch.compile。 預設為 False

注意

為了恢復儲存中的軌跡分割,SliceSamplerWithoutReplacement 將首先嘗試在儲存中尋找 traj_key 條目。 如果找不到,則會使用 end_key 來重建 episodes。

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement
>>>
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1000),
...     # asking for 10 slices for a total of 320 elements, ie, 10 trajectories of 32 transitions each
...     sampler=SliceSamplerWithoutReplacement(num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> # since we want trajectories of 32 transitions but there are only 4 episodes to
>>> # sample from, we only get 4 x 32 = 128 transitions in this batch
>>> print("sample:", sample)
>>> print("trajectories in sample", sample.get("episode").unique())

SliceSamplerWithoutReplacement 預設與 TorchRL 的大多數資料集相容,並允許使用者以類似 dataloader 的方式使用資料集

範例

>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSamplerWithoutReplacement
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320,
...     sampler=SliceSamplerWithoutReplacement(num_slices=num_slices))
>>> # the last sample is kept, since drop_last=False by default
>>> for i, batch in enumerate(data):
...     print(batch.get("episode").unique())
tensor([ 5,  6,  8, 11, 12, 14, 16, 17, 19, 24])
tensor([ 1,  2,  7,  9, 10, 13, 15, 18, 21, 22])
tensor([ 0,  3,  4, 20, 23])

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源