tensorclass¶
@tensorclass
修飾器可協助您建立自訂類別,這些類別繼承自 TensorDict
的行為,同時能夠將可能的條目限制為預定義的集合,或為您的類別實作自訂方法。
與 TensorDict
類似,@tensorclass
支援巢狀、索引、重塑、項目賦值。它還支援張量運算,如 clone
、squeeze
、torch.cat
、split
等等。@tensorclass
允許非張量條目,但是所有張量運算都嚴格限制於張量屬性。
使用者需要為非張量資料實作其自訂方法。重要的是要注意 @tensorclass
不強制執行嚴格的類型匹配
>>> from __future__ import annotations
>>> from tensordict.prototype import tensorclass
>>> import torch
>>> from torch import nn
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
... floatdata: torch.Tensor
... intdata: torch.Tensor
... non_tensordata: str
... nested: Optional[MyData] = None
...
... def check_nested(self):
... assert self.nested is not None
>>>
>>> data = MyData(
... floatdata=torch.randn(3, 4, 5),
... intdata=torch.randint(10, (3, 4, 1)),
... non_tensordata="test",
... batch_size=[3, 4]
... )
>>> print("data:", data)
data: MyData(
floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
non_tensordata='test',
nested=None,
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)
>>> data.nested = MyData(
... floatdata = torch.randn(3, 4, 5),
... intdata=torch.randint(10, (3, 4, 1)),
... non_tensordata="nested_test",
... batch_size=[3, 4]
... )
>>> print("nested:", data)
nested: MyData(
floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
non_tensordata='test',
nested=MyData(
floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
non_tensordata='nested_test',
nested=None,
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False),
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)
如同 TensorDict
的情況,從 v0.4 開始,如果省略了批次大小,則認為它是空的。
如果提供了非空的批次大小,@tensorclass
支援索引。在內部,張量物件會被索引,但是非張量資料保持不變
>>> print("indexed:", data[:2])
indexed: MyData(
floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
non_tensordata='test',
nested=MyData(
floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
non_tensordata='nested_test',
nested=None,
batch_size=torch.Size([2, 4]),
device=None,
is_shared=False),
batch_size=torch.Size([2, 4]),
device=None,
is_shared=False)
@tensorclass
也支援設定和重設屬性,甚至是巢狀物件。
>>> data.non_tensordata = "test_changed"
>>> print("data.non_tensordata: ", repr(data.non_tensordata))
data.non_tensordata: 'test_changed'
>>> data.floatdata = torch.ones(3, 4, 5)
>>> print("data.floatdata:", data.floatdata)
data.floatdata: tensor([[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]])
>>> # Changing nested tensor data
>>> data.nested.non_tensordata = "nested_test_changed"
>>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
data.nested.non_tensordata: 'nested_test_changed'
@tensorclass
支援對其內容的形狀和裝置進行多個 torch 運算,例如 stack、cat、reshape 或 to(device)。若要取得支援運算的完整列表,請查看 tensordict 文件。
這是一個範例
>>> data2 = data.clone()
>>> cat_tc = torch.cat([data, data2], 0)
>>> print("Concatenated data:", catted_tc)
Concatenated data: MyData(
floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
non_tensordata='test_changed',
nested=MyData(
floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
non_tensordata='nested_test_changed',
nested=None,
batch_size=torch.Size([6, 4]),
device=None,
is_shared=False),
batch_size=torch.Size([6, 4]),
device=None,
is_shared=False)
序列化¶
可以使用 memmap 方法來儲存 tensorclass 實例。儲存策略如下:張量資料將使用記憶體對應的張量儲存,並且可以使用 json 格式序列化的非張量資料將以這種方式儲存。其他資料類型將使用 save()
儲存,它依賴於 pickle。
可以透過 load_memmap()
還原序列化 tensorclass。只要 tensorclass 在工作環境中可用,建立的實例將具有與儲存的實例相同的類型
>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))
邊緣案例¶
@tensorclass
支援等式和不等式運算符,甚至是巢狀物件。請注意,不會驗證非張量/元資料。這將傳回一個張量類別物件,其張量屬性的布林值和非張量屬性的 None 值
這是一個範例
>>> print(data == data2)
MyData(
floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
non_tensordata=None,
nested=MyData(
floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
non_tensordata=None,
nested=None,
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False),
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)
@tensorclass
支援設定項目。但是,在設定項目時,會執行非張量/元資料的身份檢查,而不是相等性,以避免效能問題。使用者需要確保項目的非張量資料與物件匹配,以避免差異。
這是一個範例
使用不同的 non_tensor
資料設定項目時,將會拋出 UserWarning
>>> data2.non_tensordata = "test_new"
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours
即使 @tensorclass
支援像是 cat()
和 stack()
的 torch 函式,非張量 / 元資料並不會被驗證。Torch 操作會在張量資料上執行,並且在返回輸出時,會採用第一個 tensor class 物件的非張量 / 元資料。使用者需要確保所有 tensor class 物件的列表具有相同的非張量資料,以避免差異。
這是一個範例
>>> data2.non_tensordata = "test_new"
>>> stack_tc = torch.cat([data, data2], dim=0)
>>> print(stack_tc)
MyData(
floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
non_tensordata='test',
nested=MyData(
floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
non_tensordata='nested_test',
nested=None,
batch_size=torch.Size([2, 3, 4]),
device=None,
is_shared=False),
batch_size=torch.Size([2, 3, 4]),
device=None,
is_shared=False)
@tensorclass
也支援預先配置,你可以將物件的屬性初始化為 None,然後再設定它們。請注意,在初始化時,內部會將 None
屬性儲存為非張量 / 元資料,而在重設時,會根據屬性值的類型,將其儲存為張量資料或非張量 / 元資料。
這是一個範例
>>> @tensorclass
... class MyClass:
... X: Any
... y: Any
>>> data = MyClass(X=None, y=None, batch_size = [3,4])
>>> data.X = torch.ones(3, 4, 5)
>>> data.y = "testing"
>>> print(data)
MyClass(
X=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
y='testing',
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)
|
用於建立 |
|
|
|
LazyStackedTensorDict 的薄型包裝器,使非張量資料的堆疊更容易識別。 |
自動轉換¶
警告
自動轉換是一個實驗性功能,未來可能會發生變化。與 python<=3.9 的相容性有限。
@tensorclass
部分支援自動轉換作為一個實驗性功能。諸如 __setattr__
、update
、update_
和 from_dict
之類的方法將嘗試將類型註釋的條目轉換為所需的 TensorDict / tensorclass 實例(除非在下面詳述的情況下)。 例如,以下程式碼會將 td 字典轉換為 TensorDict
,並將 tc 條目轉換為 MyClass
實例
>>> @tensorclass
... class MyClass:
... tensor: torch.Tensor
... td: TensorDict
... tc: MyClass
...
>>> obj = MyClass(
... tensor=torch.randn(()),
... td={"a": torch.randn(())},
... tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.tc, TensorDict)
>>> assert isinstance(obj.td, MyClass)
注意
包含 typing.Optional
或 typing.Union
的類型註釋項目將與自動轉換不相容,但 tensorclass 中的其他項目將相容
>>> @tensorclass
... class MyClass:
... tensor: torch.Tensor
... tc_autocast: MyClass = None
... tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
... tensor=torch.randn(()),
... tc_autocast={"tensor": torch.randn(())},
... tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)
如果類別中至少有一個項目使用 type0 | type1
語義進行註釋,則整個類別的自動轉換功能都會停用。由於 tensorclass
支援非張量葉節點,因此在這些情況下設定字典會導致將其設定為普通字典,而不是張量集合子類別 (TensorDict
或 tensorclass
)
>>> @tensorclass
... class MyClass:
... tensor: torch.Tensor
... td: TensorDict
... tc: MyClass | None
...
>>> obj = MyClass(
... tensor=torch.randn(()),
... td={"a": torch.randn(())},
... tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)
注意
自動轉換不適用於葉節點 (張量)。這樣做的原因是此功能與包含 type0 | type1
類型提示語義的類型註釋不相容,而這種語義非常普遍。允許自動轉換會導致非常相似的程式碼,如果類型註釋僅略有不同,則會有截然不同的行為。