捷徑

tensorclass

@tensorclass 修飾器可協助您建立自訂類別,這些類別繼承自 TensorDict 的行為,同時能夠將可能的條目限制為預定義的集合,或為您的類別實作自訂方法。

TensorDict 類似,@tensorclass 支援巢狀、索引、重塑、項目賦值。它還支援張量運算,如 clonesqueezetorch.catsplit 等等。@tensorclass 允許非張量條目,但是所有張量運算都嚴格限制於張量屬性。

使用者需要為非張量資料實作其自訂方法。重要的是要注意 @tensorclass 不強制執行嚴格的類型匹配

>>> from __future__ import annotations
>>> from tensordict.prototype import tensorclass
>>> import torch
>>> from torch import nn
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     floatdata: torch.Tensor
...     intdata: torch.Tensor
...     non_tensordata: str
...     nested: Optional[MyData] = None
...
...     def check_nested(self):
...         assert self.nested is not None
>>>
>>> data = MyData(
...   floatdata=torch.randn(3, 4, 5),
...   intdata=torch.randint(10, (3, 4, 1)),
...   non_tensordata="test",
...   batch_size=[3, 4]
... )
>>> print("data:", data)
data: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=None,
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)
>>> data.nested = MyData(
...     floatdata = torch.randn(3, 4, 5),
...     intdata=torch.randint(10, (3, 4, 1)),
...     non_tensordata="nested_test",
...     batch_size=[3, 4]
... )
>>> print("nested:", data)
nested: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=MyData(
      floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([3, 4]),
      device=None,
      is_shared=False),
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)

如同 TensorDict 的情況,從 v0.4 開始,如果省略了批次大小,則認為它是空的。

如果提供了非空的批次大小,@tensorclass 支援索引。在內部,張量物件會被索引,但是非張量資料保持不變

>>> print("indexed:", data[:2])
indexed: MyData(
   floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test',
   nested=MyData(
      floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([2, 4]),
      device=None,
      is_shared=False),
   batch_size=torch.Size([2, 4]),
   device=None,
   is_shared=False)

@tensorclass 也支援設定和重設屬性,甚至是巢狀物件。

>>> data.non_tensordata = "test_changed"
>>> print("data.non_tensordata: ", repr(data.non_tensordata))
data.non_tensordata: 'test_changed'

>>> data.floatdata = torch.ones(3, 4, 5)
>>> print("data.floatdata:", data.floatdata)
data.floatdata: tensor([[[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]]])

>>> # Changing nested tensor data
>>> data.nested.non_tensordata = "nested_test_changed"
>>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
data.nested.non_tensordata: 'nested_test_changed'

@tensorclass 支援對其內容的形狀和裝置進行多個 torch 運算,例如 stackcatreshapeto(device)。若要取得支援運算的完整列表,請查看 tensordict 文件。

這是一個範例

>>> data2 = data.clone()
>>> cat_tc = torch.cat([data, data2], 0)
>>> print("Concatenated data:", catted_tc)
Concatenated data: MyData(
   floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test_changed',
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
       intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
       non_tensordata='nested_test_changed',
       nested=None,
       batch_size=torch.Size([6, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([6, 4]),
   device=None,
   is_shared=False)

序列化

可以使用 memmap 方法來儲存 tensorclass 實例。儲存策略如下:張量資料將使用記憶體對應的張量儲存,並且可以使用 json 格式序列化的非張量資料將以這種方式儲存。其他資料類型將使用 save() 儲存,它依賴於 pickle

可以透過 load_memmap() 還原序列化 tensorclass。只要 tensorclass 在工作環境中可用,建立的實例將具有與儲存的實例相同的類型

>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))

邊緣案例

@tensorclass 支援等式和不等式運算符,甚至是巢狀物件。請注意,不會驗證非張量/元資料。這將傳回一個張量類別物件,其張量屬性的布林值和非張量屬性的 None 值

這是一個範例

>>> print(data == data2)
MyData(
   floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
   intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
   non_tensordata=None,
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
       intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
       non_tensordata=None,
       nested=None,
       batch_size=torch.Size([3, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

@tensorclass 支援設定項目。但是,在設定項目時,會執行非張量/元資料的身份檢查,而不是相等性,以避免效能問題。使用者需要確保項目的非張量資料與物件匹配,以避免差異。

這是一個範例

使用不同的 non_tensor 資料設定項目時,將會拋出 UserWarning

>>> data2.non_tensordata = "test_new"
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours

即使 @tensorclass 支援像是 cat()stack() 的 torch 函式,非張量 / 元資料並不會被驗證。Torch 操作會在張量資料上執行,並且在返回輸出時,會採用第一個 tensor class 物件的非張量 / 元資料。使用者需要確保所有 tensor class 物件的列表具有相同的非張量資料,以避免差異。

這是一個範例

>>> data2.non_tensordata = "test_new"
>>> stack_tc = torch.cat([data, data2], dim=0)
>>> print(stack_tc)
MyData(
    floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
    non_tensordata='test',
    nested=MyData(
        floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        non_tensordata='nested_test',
        nested=None,
        batch_size=torch.Size([2, 3, 4]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)

@tensorclass 也支援預先配置,你可以將物件的屬性初始化為 None,然後再設定它們。請注意,在初始化時,內部會將 None 屬性儲存為非張量 / 元資料,而在重設時,會根據屬性值的類型,將其儲存為張量資料或非張量 / 元資料。

這是一個範例

>>> @tensorclass
... class MyClass:
...   X: Any
...   y: Any

>>> data = MyClass(X=None, y=None, batch_size = [3,4])
>>> data.X = torch.ones(3, 4, 5)
>>> data.y = "testing"
>>> print(data)
MyClass(
   X=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   y='testing',
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

tensorclass([cls, autocast, frozen])

用於建立 tensorclass 類別的裝飾器。

NonTensorData(data[, _metadata, ...])

NonTensorStack(*args, **kwargs)

LazyStackedTensorDict 的薄型包裝器,使非張量資料的堆疊更容易識別。

自動轉換

警告

自動轉換是一個實驗性功能,未來可能會發生變化。與 python<=3.9 的相容性有限。

@tensorclass 部分支援自動轉換作為一個實驗性功能。諸如 __setattr__updateupdate_from_dict 之類的方法將嘗試將類型註釋的條目轉換為所需的 TensorDict / tensorclass 實例(除非在下面詳述的情況下)。 例如,以下程式碼會將 td 字典轉換為 TensorDict,並將 tc 條目轉換為 MyClass 實例

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.tc, TensorDict)
>>> assert isinstance(obj.td, MyClass)

注意

包含 typing.Optionaltyping.Union 的類型註釋項目將與自動轉換不相容,但 tensorclass 中的其他項目將相容

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     tc_autocast: MyClass = None
...     tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     tc_autocast={"tensor": torch.randn(())},
...     tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)

如果類別中至少有一個項目使用 type0 | type1 語義進行註釋,則整個類別的自動轉換功能都會停用。由於 tensorclass 支援非張量葉節點,因此在這些情況下設定字典會導致將其設定為普通字典,而不是張量集合子類別 (TensorDicttensorclass)

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass | None
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)

注意

自動轉換不適用於葉節點 (張量)。這樣做的原因是此功能與包含 type0 | type1 類型提示語義的類型註釋不相容,而這種語義非常普遍。允許自動轉換會導致非常相似的程式碼,如果類型註釋僅略有不同,則會有截然不同的行為。

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源