快捷方式

tensorclass

class tensordict.tensorclass(cls=None, /, *, autocast: bool = False, frozen: bool = False)

一個用來建立 tensorclass 類別的裝飾器。

tensorclass 類別是專門的 dataclasses.dataclass() 實例,可以直接執行一些預先定義的 tensor 操作,例如索引、項目賦值、重塑、轉換為裝置或儲存等。

參數:
  • autocast (bool, optional) – 如果 True,則在設定引數時將強制執行指示的類型。 預設為 False

  • frozen (bool, optional) – 如果 True,則無法修改 tensorclass 的內容。 提供此引數是為了與 dataclass 相容,可以透過類別建構函式中的 lock 引數獲得類似的行為。 預設為 False

tensorclass 可以帶參數或不帶參數使用: .. rubric:: 範例

>>> @tensorclass
... class X:
...     y: torch.Tensor
>>> X(1).y
1
>>> @tensorclass(autocast=False)
... class X:
...     y: torch.Tensor
>>> X(1).y
1
>>> @tensorclass(autocast=True)
... class X:
...     y: torch.Tensor
>>> X(1).y
torch.tensor(1)

範例

>>> from tensordict import tensorclass
>>> import torch
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     X: torch.Tensor
...     y: torch.Tensor
...     z: str
...     def expand_and_mask(self):
...         X = self.X.unsqueeze(-1).expand_as(self.y)
...         X = X[self.y]
...         return X
...
>>> data = MyData(
...     X=torch.ones(3, 4, 1),
...     y=torch.zeros(3, 4, 2, 2, dtype=torch.bool),
...     z="test"
...     batch_size=[3, 4])
>>> print(data)
MyData(
    X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
    y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool),
    z="test"
    batch_size=[3, 4],
    device=None,
    is_shared=False)
>>> print(data.expand_and_mask())
tensor([])
也可以將 tensorclasses 實例彼此巢狀

範例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # although the data is stored as a TensorDict, the type hint helps us >>> # to appropriately cast the data to the right type >>> assert isinstance(nesting_data.nested, type(data))

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源