快捷方式

SliceSampler

class torchrl.data.replay_buffers.SliceSampler(*, num_slices: Optional[int] = None, slice_len: Optional[int] = None, end_key: Optional[NestedKey] = None, traj_key: Optional[NestedKey] = None, ends: Optional[Tensor] = None, trajectories: Optional[Tensor] = None, cache_values: bool = False, truncated_key: tensordict._nestedkey.NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: Union[bool, int, Tuple[bool | int, bool | int]] = False)[source]

沿著第一維度採樣資料切片,給定起始和停止訊號。

這個類別會透過替換方式採樣子軌跡。對於沒有替換的版本,請參閱 SliceSamplerWithoutReplacement

關鍵字引數:
  • num_slices (int) – 要採樣的切片數量。批次大小必須大於或等於 num_slices 引數。與 slice_len 互斥。

  • slice_len (int) – 要採樣的切片長度。批次大小必須大於或等於 slice_len 引數,並且可以被它整除。與 num_slices 互斥。

  • end_key (NestedKey, optional) – 指示軌跡(或 episode)結束的鍵。預設為 ("next", "done")

  • traj_key (NestedKey, optional) – 指示軌跡的鍵。預設為 "episode"(通常用於 TorchRL 中的資料集)。

  • ends (torch.Tensor, optional) – 包含執行結束訊號的一維布林張量。每當獲取 end_keytraj_key 的成本很高,或此訊號已準備好可用時使用。必須與 cache_values=True 一起使用,並且不能與 end_keytraj_key 結合使用。如果提供,則假定儲存已達到容量,並且如果 ends 張量的最後一個元素為 False,則相同的軌跡跨越結束和開始。

  • trajectories (torch.Tensor, optional) – 包含執行 ID 的一維整數張量。每當獲取 end_keytraj_key 的成本很高,或此訊號已準備好可用時使用。必須與 cache_values=True 一起使用,並且不能與 end_keytraj_key 結合使用。如果提供,則假定儲存已達到容量,並且如果軌跡張量的最後一個元素與第一個元素相同,則相同的軌跡跨越結束和開始。

  • cache_values (bool, optional) –

    用於靜態資料集。將會快取軌跡的起始和結束訊號。即使在呼叫 extend 期間軌跡索引發生變化,也可以安全使用,因為此操作會清除快取。

    警告

    如果取樣器與被另一個緩衝區擴展的儲存器一起使用,則 cache_values=True 將無法工作。 例如:

    >>> buffer0 = ReplayBuffer(storage=storage,
    ...     sampler=SliceSampler(num_slices=8, cache_values=True),
    ...     writer=ImmutableWriter())
    >>> buffer1 = ReplayBuffer(storage=storage,
    ...     sampler=other_sampler)
    >>> # Wrong! Does not erase the buffer from the sampler of buffer0
    >>> buffer1.extend(data)
    

    警告

    如果緩衝區在進程之間共享,並且一個進程負責寫入,另一個進程負責取樣,則 cache_values=True 將無法按預期工作,因為清除快取只能在本地完成。

  • truncated_key (NestedKey, optional) – 如果不是 None,則此參數指示截斷訊號應寫入輸出資料的位置。 這用於向值估算器指示提供的軌跡在哪裡中斷。預設值為 ("next", "truncated")。此功能僅適用於 TensorDictReplayBuffer 實例(否則,截斷的鍵將在 sample() 方法傳回的資訊字典中傳回)。

  • strict_length (bool, optional) – 如果 False,則允許長度短於 slice_len (或 batch_size // num_slices) 的軌跡出現在批次中。如果 True,則會濾除短於要求的軌跡。請注意,這可能會導致有效的 batch_size 短於要求的批次大小!可以使用 split_trajectories() 分割軌跡。預設值為 True

  • compile (bool or dict of kwargs, optional) – 如果 Truesample() 方法的瓶頸將使用 compile() 編譯。也可以使用此參數將關鍵字引數傳遞給 torch.compile。預設值為 False

  • span (bool, int, Tuple[bool | int, bool | int], optional) – 如果提供,取樣的軌跡將跨越左側和/或右側。這意味著提供的元素可能比要求的少。布林值表示每個軌跡至少取樣一個元素。整數 i 表示為每個取樣的軌跡至少收集 slice_len - i 個樣本。使用元組可以精細地控制左側(儲存軌跡的開頭)和右側(儲存軌跡的結尾)的跨度。

注意

為了恢復儲存中的軌跡分割,SliceSampler 將首先嘗試在儲存中找到 traj_key 條目。如果找不到,將使用 end_key 來重建 episodes。

範例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1_000_000),
...     sampler=SliceSampler(cache_values=True, num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)

SliceSampler 與 TorchRL 的大多數資料集預設相容

範例

>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
...     batch = batch.reshape(num_slices, -1)
...     break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
        [14],
        [ 8],
        [10],
        [13],
        [ 4],
        [ 2],
        [ 3],
        [22],
        [ 8]])

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得針對初學者和高級開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源