捷徑

split_trajectories

torchrl.collectors.utils.split_trajectories(rollout_tensordict: TensorDictBase, *, prefix=None, trajectory_key: tensordict._nestedkey.NestedKey | None = None, done_key: tensordict._nestedkey.NestedKey | None = None, as_nested: bool = False) TensorDictBase[原始碼]

用於軌跡分離的實用函式。

接受帶有關鍵字 traj_ids 的 tensordict,該關鍵字指示每個軌跡的 id。

從那裡,建立一個 B x T x … 零填充 tensordict,其中 B 個批次在最大持續時間 T 上

參數:

rollout_tensordict (TensorDictBase) – 一個沿著最後一個維度具有相鄰軌跡的 rollout。

關鍵字參數:
  • prefix (NestedKey, optional) – 用於讀取和寫入元資料的前綴,例如 "traj_ids"(每個軌跡的可選整數 id)和 "mask" 條目,指示哪些資料有效,哪些資料無效。如果輸入具有 "collector" 條目,則預設為 "collector",否則預設為 () (沒有前綴)。prefix 保留為舊版功能,最終將被棄用。盡可能優先使用 trajectory_keydone_key

  • trajectory_key (NestedKey, optional) – 指向軌跡 id 的鍵。取代 done_keyprefix。如果未提供,則預設為 (prefix, "traj_ids")

  • done_key (NestedKey, optional) – 指向 "done"" 訊號的鍵,如果無法直接恢復軌跡。預設為 "done"

  • as_nested (booltorch.layout, 選用) –

    是否將結果以巢狀張量 (nested tensors) 的形式返回。預設值為 False。如果提供了 torch.layout,它將用於構建巢狀張量,否則將使用預設的 layout。

    注意

    使用 split_trajectories(tensordict, as_nested=True).to_padded_tensor(mask=mask_key) 應該會得到與 as_nested=False 完全相同的結果。由於這是一個實驗性功能,並且依賴於 nested_tensors,其 API 可能在未來發生變化,因此我們將其設為一個可選功能。使用 as_nested=True 時,執行速度應該更快。

    注意

    提供 layout 讓使用者可以控制巢狀張量要與 torch.strided 還是 torch.jagged layout 一起使用。雖然前者在撰寫本文時的功能略多,但後者將是 PyTorch 團隊未來的主要關注點,因為它與 compile() 具有更好的相容性。

回傳:

一個新的 tensordict,其 leading dimension 對應於軌跡 (trajectory)。還會新增一個 "mask" 布林值條目,它共享 trajectory_key 前綴和 tensordict shape。它指示 tensordict 的有效元素,並且如果找不到 trajectory_key,還會新增一個 "traj_ids" 條目。

範例

>>> from tensordict import TensorDict
>>> import torch
>>> from torchrl.collectors.utils import split_trajectories
>>> obs = torch.cat([torch.arange(10), torch.arange(5)])
>>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)])
>>> done = torch.zeros(15, dtype=torch.bool)
>>> done[9] = True
>>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32),
...     torch.ones(5, dtype=torch.int32)])
>>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done, "trajectory": trajectory_id}, batch_size=[15])
>>> data_split = split_trajectories(data, done_key="done")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
>>> # check that split_trajectories got the trajectories right with the done signal
>>> assert (data_split["traj_ids"] == data_split["trajectory"]).all()
>>> print(data_split["mask"])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False, False, False, False, False]])
>>> data_split = split_trajectories(data, trajectory_key="trajectory")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得問題的解答

檢視資源