PrioritizedSliceSampler¶
- class torchrl.data.replay_buffers.PrioritizedSliceSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: dtype = torch.float32, reduction: str = 'max', *, num_slices: Optional[int] = None, slice_len: Optional[int] = None, end_key: Optional[NestedKey] = None, traj_key: Optional[NestedKey] = None, ends: Optional[Tensor] = None, trajectories: Optional[Tensor] = None, cache_values: bool = False, truncated_key: tensordict._nestedkey.NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: Union[bool, int, Tuple[bool | int, bool | int]] = False, max_priority_within_buffer: bool = False)[source]¶
使用優先採樣,沿著第一個維度採樣資料片段,給定開始和停止訊號。
- 此類別根據 Schaul, T.; Quan, J.; Antonoglou, I.; 和 Silver, D. 2015 年提出的優先權重,進行子軌跡的替換採樣。
優先經驗重播。(https://arxiv.org/abs/1511.05952)
如需更多資訊,請參閱
SliceSampler
和PrioritizedSampler
。警告
PrioritizedSliceSampler 將查看個別轉換的優先順序,並據此採樣起點。這表示,如果優先順序較低的轉換跟在另一個優先順序較高的轉換之後,它們也可能出現在樣本中,而如果優先順序較高的轉換更接近軌跡的末端,如果它們不能用作起點,則可能永遠不會被採樣。目前,使用者有責任使用
update_priority()
彙總軌跡項目的優先順序。- 參數:
alpha (float) – 指數 α 決定優先順序的使用程度,其中 α = 0 對應於均勻情況。
beta (float) – 重要性採樣負指數。
eps (float, optional) – 新增至優先順序的增量,以確保緩衝區不包含空優先順序。預設值為 1e-8。
reduction (str, optional) – 多維 Tensordict(即儲存的軌跡)的縮減方法。可以是「max」、「min」、「median」或「mean」之一。
- 關鍵字引數:
num_slices (int) – 要採樣的片段數量。批次大小必須大於或等於
num_slices
引數。與slice_len
互斥。slice_len (int) – 要採樣的片段長度。批次大小必須大於或等於
slice_len
引數,且可被其整除。與num_slices
互斥。end_key (NestedKey, optional) – 指示軌跡(或 episode)結束的鍵。預設值為
("next", "done")
。traj_key (NestedKey, optional) – 指示軌跡的鍵。預設值為
"episode"
(TorchRL 中資料集常用的鍵)。ends (torch.Tensor, optional) – 包含執行結束訊號的 1 維布林張量。在
end_key
或traj_key
取得成本高昂,或此訊號已準備就緒時使用。必須與cache_values=True
一起使用,且不能與end_key
或traj_key
結合使用。trajectories (torch.Tensor, optional) – 包含執行 ID 的 1 維整數張量。在
end_key
或traj_key
取得成本高昂,或此訊號已準備就緒時使用。必須與cache_values=True
一起使用,且不能與end_key
或traj_key
結合使用。cache_values (bool, optional) –
與靜態資料集一起使用。將快取軌跡的開始和結束訊號。即使軌跡索引在呼叫
extend
期間變更,也可以安全地使用此功能,因為此操作將清除快取。警告
如果取樣器與由另一個緩衝區擴充的儲存體一起使用,則
cache_values=True
將無法運作。例如>>> buffer0 = ReplayBuffer(storage=storage, ... sampler=SliceSampler(num_slices=8, cache_values=True), ... writer=ImmutableWriter()) >>> buffer1 = ReplayBuffer(storage=storage, ... sampler=other_sampler) >>> # Wrong! Does not erase the buffer from the sampler of buffer0 >>> buffer1.extend(data)
警告
如果緩衝區在進程之間共用,且一個進程負責寫入,而另一個進程負責取樣,則
cache_values=True
將無法如預期運作,因為清除快取只能在本機完成。truncated_key (NestedKey, optional) – 如果不是
None
,則此引數指示應在輸出資料中寫入截斷訊號的位置。這用於向值估計器指示提供的軌跡中斷的位置。預設值為("next", "truncated")
。此功能僅適用於TensorDictReplayBuffer
實例(否則截斷鍵會在sample()
方法傳回的資訊字典中傳回)。strict_length (bool, optional) – 如果
False
,則長度小於 slice_len(或 batch_size // num_slices)的軌跡將被允許出現在批次中。如果True
,則會篩選掉短於要求的軌跡。請注意,這可能會導致有效的 batch_size 短於要求的批次大小!可以使用split_trajectories()
分割軌跡。預設值為True
。compile (bool or dict of kwargs, optional) – 如果
True
,則sample()
方法的瓶頸將使用compile()
編譯。關鍵字引數也可以透過此 arg 傳遞至 torch.compile。預設值為False
。span (bool, int, Tuple[bool | int, bool | int], optional) – 如果提供,則採樣的軌跡將跨越左側和/或右側。這表示提供的元素可能少於要求的元素。布林值表示每個軌跡至少採樣一個元素。整數 i 表示每個採樣的軌跡至少收集 slice_len - i 個樣本。使用元組可以精細控制左側(儲存軌跡的開頭)和右側(儲存軌跡的結尾)的跨度。
max_priority_within_buffer (bool, optional) – 如果
True
,則在緩衝區內追蹤最大優先順序。當False
時,最大優先順序會追蹤自取樣器實例化以來的最大值。預設值為False
。
範例
>>> import torch >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler >>> from tensordict import TensorDict >>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9) >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6) >>> data = TensorDict( ... { ... "observation": torch.randn(9,16), ... "action": torch.randn(9, 1), ... "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long), ... "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long), ... ("next", "observation"): torch.randn(9,16), ... ("next", "reward"): torch.randn(9,1), ... ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1), ... }, ... batch_size=[9], ... ) >>> rb.extend(data) >>> sample, info = rb.sample(return_info=True) >>> print("episode", sample["episode"].tolist()) episode [2, 2, 2, 2, 1, 1] >>> print("steps", sample["steps"].tolist()) steps [1, 2, 0, 1, 1, 2] >>> print("weight", info["_weight"].tolist()) weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] >>> priority = torch.tensor([0,3,3,0,0,0,1,1,1]) >>> rb.update_priority(torch.arange(0,9,1), priority=priority) >>> sample, info = rb.sample(return_info=True) >>> print("episode", sample["episode"].tolist()) episode [2, 2, 2, 2, 2, 2] >>> print("steps", sample["steps"].tolist()) steps [1, 2, 0, 1, 0, 1] >>> print("weight", info["_weight"].tolist()) weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
- update_priority(index: Union[int, Tensor], priority: Union[float, Tensor], *, storage: torchrl.data.replay_buffers.storages.TensorStorage | None = None) None ¶
更新索引所指向資料的優先順序。
- 參數:
index (int or torch.Tensor) – 要更新的優先順序索引。
priority (Number or torch.Tensor) – 已索引元素的新優先順序。
- 關鍵字引數:
storage (Storage, optional) – 用於將 Nd 索引大小對應至 sum_tree 和 min_tree 的 1 維大小的儲存體。僅在
index.ndim > 2
時才需要。