• 文件 >
  • 切片、索引和遮罩
捷徑

切片、索引和遮罩

作者: Tom Begley

在本教學中,您將學習如何對 TensorDict 進行切片、索引和遮罩。

如教學 操作 TensorDict 的形狀中所討論的,當我們建立 TensorDict 時,我們指定一個 batch_size,它必須與 TensorDict 中所有條目的前導維度一致。由於我們保證所有條目都共享這些共同的維度,因此我們能夠以與索引 torch.Tensor 相同的方式,索引和遮罩批次維度。索引會沿著批次維度應用於 TensorDict 中的所有條目。

例如,給定一個具有兩個批次維度的 TensorDicttensordict[0] 會傳回一個具有相同結構的新 TensorDict,其值對應於原始 TensorDict 中每個條目的第一個「列」。

import torch
from tensordict import TensorDict

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)

print(tensordict[0])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

相同的語法適用於常規張量。例如,如果我們想要刪除每個條目的第一列,我們可以按如下方式進行索引

print(tensordict[1:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 4]),
    device=None,
    is_shared=False)

我們可以同時索引多個維度

print(tensordict[:, 2:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

我們也可以使用 Ellipsis 來表示盡可能多的 :,以使選擇元組的長度與 tensordict.batch_dims 的長度相同。

print(tensordict[..., 2:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

使用索引設定值

一般來說,只要批次大小相容,tensordict[index] = new_tensordict 即可正常運作。

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)

td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
print(tensordict["a"], tensordict["b"])
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]]) tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.]])

遮罩

我們遮罩 TensorDict,就像我們遮罩張量一樣。

mask = torch.BoolTensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]])
tensordict[mask]
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([6]),
    device=None,
    is_shared=False)

腳本的總執行時間:(0 分鐘 0.004 秒)

由 Sphinx-Gallery 產生的圖庫

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

獲取針對初學者和進階開發者的深度教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源