快捷方式

ReplayBuffer

class torchrl.data.ReplayBuffer(*, storage: Storage | None = None, sampler: Sampler | None = None, writer: Writer | None = None, collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, transform: 'Transform' | None = None, batch_size: int | None = None, dim_extend: int | None = None, checkpointer: 'StorageCheckpointerBase' | None = None, generator: torch.Generator | None = None, shared: bool = False)[source]

一個通用的、可組合的回放緩衝區類別。

關鍵字參數:
  • storage (Storage, optional) – 要使用的儲存空間。如果未提供,將會建立預設的 ListStorage,其 max_size1_000

  • sampler (Sampler, optional) – 要使用的取樣器。如果未提供,將會使用預設的 RandomSampler

  • writer (Writer, optional) – 要使用的寫入器。如果未提供,將會使用預設的 RoundRobinWriter

  • collate_fn (callable, optional) – 將樣本列表合併以形成 Tensor(s)/輸出的 mini-batch。當從 map-style 資料集使用批次載入時使用。預設值將根據儲存類型決定。

  • pin_memory (bool) – 是否應在 rb 樣本上調用 pin_memory()。

  • prefetch (int, optional) – 使用多執行緒預先獲取的下一個批次的數量。預設為 None(無預先獲取)。

  • transform (Transform, optional) – 在調用 sample() 時要執行的轉換。要鏈式轉換,請使用 Compose 類別。轉換應與 tensordict.TensorDict 內容一起使用。如果回放緩衝區與 PyTree 結構一起使用,也可以傳遞通用可調用物件(請參閱下面的範例)。

  • batch_size (int, optional) –

    調用 sample() 時要使用的批次大小。 .. note

    The batch-size can be specified at construction time via the
    ``batch_size`` argument, or at sampling time. The former should
    be preferred whenever the batch-size is consistent across the
    experiment. If the batch-size is likely to change, it can be
    passed to the :meth:`~.sample` method. This option is
    incompatible with prefetching (since this requires to know the
    batch-size in advance) as well as with samplers that have a
    ``drop_last`` argument.
    

  • dim_extend (int, optional) –

    指定調用 extend() 時要考慮用於擴展的維度。預設值為 storage.ndim-1。當使用 dim_extend > 0 時,如果 storage 實例化時有 ndim 參數可用,建議使用它,以便讓 storages 知道資料是多維的,並在抽樣期間保持 storage 容量和批次大小概念的一致性。

    注意

    此參數對 add() 無效,因此當 add()extend() 都在程式碼中使用時,應謹慎使用。例如:

    >>> data = torch.zeros(3, 4)
    >>> rb = ReplayBuffer(
    ...     storage=LazyTensorStorage(10, ndim=2),
    ...     dim_extend=1)
    >>> # these two approaches are equivalent:
    >>> for d in data.unbind(1):
    ...     rb.add(d)
    >>> rb.extend(data)
    

  • generator (torch.Generator, optional) –

    用於抽樣的 generator。為 replay buffer 使用專用 generator 可以對 seeding 進行精細控制,例如,保持全域 seed 不同,但為分散式任務保持 RB seed 相同。預設值為 None (全域預設 generator)。

    警告

    到目前為止,generator 對 transforms 沒有任何影響。

  • shared (bool, optional) – 是否使用多進程共享 buffer。預設值為 False

範例

>>> import torch
>>>
>>> from torchrl.data import ReplayBuffer, ListStorage
>>>
>>> torch.manual_seed(0)
>>> rb = ReplayBuffer(
...     storage=ListStorage(max_size=1000),
...     batch_size=5,
... )
>>> # populate the replay buffer and get the item indices
>>> data = range(10)
>>> indices = rb.extend(data)
>>> # sample will return as many elements as specified in the constructor
>>> sample = rb.sample()
>>> print(sample)
tensor([4, 9, 3, 0, 3])
>>> # Passing the batch-size to the sample method overrides the one in the constructor
>>> sample = rb.sample(batch_size=3)
>>> print(sample)
tensor([9, 7, 3])
>>> # one cans sample using the ``sample`` method or iterate over the buffer
>>> for i, batch in enumerate(rb):
...     print(i, batch)
...     if i == 3:
...         break
0 tensor([7, 3, 1, 6, 6])
1 tensor([9, 8, 6, 6, 8])
2 tensor([4, 3, 6, 9, 1])
3 tensor([4, 4, 1, 9, 9])

Replay buffers 接受任何種類的資料。並非所有 storage 類型都能正常工作,因為有些只期望數值資料,但預設的 ListStorage 可以。

範例

>>> torch.manual_seed(0)
>>> buffer = ReplayBuffer(storage=ListStorage(100), collate_fn=lambda x: x)
>>> indices = buffer.extend(["a", 1, None])
>>> buffer.sample(3)
[None, 'a', None]

TensorStorageLazyMemmapStorageLazyTensorStorage 也適用於任何 PyTree 結構(PyTree 是一個任意深度的巢狀結構,由 dicts、lists 或 tuples 組成,其中 leaves 是 tensors),前提是它僅包含 tensor 資料。

範例

>>> from torch.utils._pytree import tree_map
>>> def transform(x):
...     # Zeros all the data in the pytree
...     return tree_map(lambda y: y * 0, x)
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=transform)
>>> data = {
...     "a": torch.randn(3),
...     "b": {"c": (torch.zeros(2), [torch.ones(1)])},
...     30: -torch.ones(()),
... }
>>> rb.add(data)
>>> # The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
>>> s = rb.sample(10)
>>> # let's check that our transform did its job:
>>> def assert0(x):
>>>     assert (x == 0).all()
>>> tree_map(assert0, s)
add(data: Any) int[source]

將單個元素新增到 replay buffer。

參數:

data (Any) – 要新增到 replay buffer 的資料

返回值:

資料在 replay buffer 中的索引位置。

append_transform(transform: Transform, *, invert: bool = False) ReplayBuffer[source]

將 transform 附加在末尾。

當呼叫 sample 時,transforms 會按順序套用。

參數:

transform (Transform) – 要附加的 transform

關鍵字參數:

invert (bool, optional) – 如果 True,則 transform 將被反轉(forward 呼叫將在寫入期間呼叫,inverse 呼叫將在讀取期間呼叫)。預設值為 False

範例

>>> rb = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=4)
>>> data = TensorDict({"a": torch.zeros(10)}, [10])
>>> def t(data):
...     data += 1
...     return data
>>> rb.append_transform(t, invert=True)
>>> rb.extend(data)
>>> assert (data == 1).all()
dump(*args, **kwargs)[source]

的別名 dumps().

dumps(path)[source]

將 replay buffer 儲存到磁碟上的指定路徑。

參數:

path (Path or str) – 儲存 replay buffer 的路徑。

範例

>>> import tempfile
>>> import tqdm
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
>>> import torch
>>> from tensordict import TensorDict
>>> # Build and populate the replay buffer
>>> S = 1_000_000
>>> sampler = PrioritizedSampler(S, 1.1, 1.0)
>>> # sampler = RandomSampler()
>>> storage = LazyMemmapStorage(S)
>>> rb = TensorDictReplayBuffer(storage=storage, sampler=sampler)
>>>
>>> for _ in tqdm.tqdm(range(100)):
...     td = TensorDict({"obs": torch.randn(100, 3, 4), "next": {"obs": torch.randn(100, 3, 4)}, "td_error": torch.rand(100)}, [100])
...     rb.extend(td)
...     sample = rb.sample(32)
...     rb.update_tensordict_priority(sample)
>>> # save and load the buffer
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     rb.dumps(tmpdir)
...
...     sampler = PrioritizedSampler(S, 1.1, 1.0)
...     # sampler = RandomSampler()
...     storage = LazyMemmapStorage(S)
...     rb_load = TensorDictReplayBuffer(storage=storage, sampler=sampler)
...     rb_load.loads(tmpdir)
...     assert len(rb) == len(rb_load)
empty()[source]

清空 replay buffer 並將游標重設為 0。

extend(data: Sequence) Tensor[source]

使用可迭代物件中包含的一個或多個元素來擴充重播緩衝區。

如果存在,將會呼叫反向轉換 (inverse transforms)。

參數:

data (iterable) – 要新增至重播緩衝區的資料集合。

返回值:

新增至重播緩衝區的資料索引。

警告

extend() 在處理數值列表時可能會有不明確的簽章,這些數值列表應被解釋為 PyTree(在這種情況下,列表中的所有元素都將放置在儲存中儲存的 PyTree 的一個片段中)或一次新增一個值的數值列表。為了避免這種情況,TorchRL 對列表和元組做了明確的區分:元組將被視為 PyTree,列表(在根層級)將被解釋為一次新增到緩衝區的值的堆疊。對於 ListStorage 實例,只能提供未綁定的元素(沒有 PyTrees)。

insert_transform(index: int, transform: Transform, *, invert: bool = False) ReplayBuffer[source]

插入轉換 (transform)。

當呼叫 sample 時,轉換會依序執行。

參數:
  • index (int) – 插入轉換的位置。

  • transform (Transform) – 要附加的 transform

關鍵字參數:

invert (bool, optional) – 如果 True,則 transform 將被反轉(forward 呼叫將在寫入期間呼叫,inverse 呼叫將在讀取期間呼叫)。預設值為 False

load(*args, **kwargs)[source]

loads() 的別名。

loads(path)[source]

從給定的路徑載入重播緩衝區狀態。

緩衝區應具有匹配的元件,並使用 dumps() 儲存。

參數:

path (Pathstr) – 重播緩衝區儲存的路徑。

有關更多資訊,請參閱 dumps()

register_load_hook(hook: Callable[[Any], Any])[source]

為儲存註冊載入 Hook。

注意

Hook 目前在儲存重播緩衝區時不會序列化:每次建立緩衝區時,都必須手動重新初始化它們。

register_save_hook(hook: Callable[[Any], Any])[source]

為儲存註冊儲存 Hook。

注意

Hook 目前在儲存重播緩衝區時不會序列化:每次建立緩衝區時,都必須手動重新初始化它們。

sample(batch_size: Optional[int] = None, return_info: bool = False) Any[source]

從重播緩衝區中採樣一批資料。

使用 Sampler 對索引進行取樣,並從 Storage 檢索它們。

參數:
  • batch_size (int, optional) – 要收集的資料大小。如果未提供,此方法將取樣一個由 sampler 指示的 batch_size。

  • return_info (bool) – 是否返回資訊。如果為 True,則結果為一個 tuple (data, info)。如果為 False,則結果為 data。

返回值:

在回放緩衝區中選擇的一批資料。如果 return_info 標誌設定為 True,則為包含此批次資料和資訊的 tuple。

property sampler

回放緩衝區的 sampler。

sampler 必須是 Sampler 的一個實例。

save(*args, **kwargs)[原始碼]

的別名 dumps().

set_sampler(sampler: Sampler)[原始碼]

在回放緩衝區中設定一個新的 sampler,並返回先前的 sampler。

set_storage(storage: Storage, collate_fn: Optional[Callable] = None)[原始碼]

在回放緩衝區中設定一個新的 storage,並返回先前的 storage。

參數:
  • storage (Storage) – 緩衝區的新 storage。

  • collate_fn (callable, optional) – 如果提供,則 collate_fn 會設定為此值。否則,它會重設為預設值。

set_writer(writer: Writer)[原始碼]

在回放緩衝區中設定一個新的 writer,並返回先前的 writer。

property storage

回放緩衝區的 storage。

storage 必須是 Storage 的一個實例。

property write_count

到目前為止,透過 add 和 extend 在緩衝區中寫入的項目總數。

property writer

回放緩衝區的 writer。

writer 必須是 Writer 的一個實例。

文件

訪問 PyTorch 的全面開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學課程

檢視教學

資源

尋找開發資源並獲得您問題的解答

檢視資源