TensorStorage¶
- class torchrl.data.replay_buffers.TensorStorage(storage, max_size=None, *, device: device = 'cpu', ndim: int = 1)[來源]¶
張量和 TensorDict 的儲存空間。
- 參數:
storage (tensor 或 TensorDict) – 要使用的資料緩衝區。
max_size (int) – 儲存空間的大小,即緩衝區中儲存的最大元素數量。
- 關鍵字引數:
device (torch.device, optional) – 抽樣張量將被儲存和傳送的裝置。預設值為
torch.device("cpu")
。如果傳遞 “auto”,則會從傳遞的第一批資料自動收集裝置。預設情況下不啟用此功能,以避免資料錯誤地放置在 GPU 上,導致 OOM 問題。ndim (int, optional) – 測量儲存空間大小時要考慮的維度數量。例如,如果
ndim=1
,則形狀為[3, 4]
的儲存空間的容量為3
,如果ndim=2
,則容量為12
。預設值為1
。
範例
>>> data = TensorDict({ ... "some data": torch.randn(10, 11), ... ("some", "nested", "data"): torch.randn(10, 11, 12), ... }, batch_size=[10, 11]) >>> storage = TensorStorage(data) >>> len(storage) # only the first dimension is considered as indexable 10 >>> storage.get(0) TensorDict( fields={ some data: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), some: TensorDict( fields={ nested: TensorDict( fields={ data: Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([11]), device=None, is_shared=False)}, batch_size=torch.Size([11]), device=None, is_shared=False)}, batch_size=torch.Size([11]), device=None, is_shared=False) >>> storage.set(0, storage.get(0).zero_()) # zeros the data along index ``0``
此類別也支援 tensorclass 資料。
範例
>>> from tensordict import tensorclass >>> @tensorclass ... class MyClass: ... foo: torch.Tensor ... bar: torch.Tensor >>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11]) >>> storage = TensorStorage(data) >>> storage.get(0) MyClass( bar=Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False), foo=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), batch_size=torch.Size([11]), device=None, is_shared=False)
- attach(buffer: Any) None ¶
此函數將取樣器附加到此儲存空間。
從此儲存空間讀取的緩衝區必須透過呼叫此方法包含為附加實體。這保證了當儲存空間中的資料發生變更時,即使儲存空間與其他緩衝區(例如,優先順序取樣器)共享,元件也會知道變更。
- 參數:
buffer – 從此儲存空間讀取的物件。