快捷方式

from_pytree

class tensordict.from_pytree(pytree, *, batch_size: Optional[Size] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None)

將 pytree 轉換為 TensorDict 實例。

此方法旨在盡可能保留 pytree 的巢狀結構。

新增額外的非張量鍵以追蹤每個層級的身分,提供內建的 pytree-to-tensordict 雙射轉換 API。

目前接受的類別包括 lists、tuples、named tuples 和 dict。

注意

對於字典,非 NestedKey 鍵會單獨註冊為 NonTensorData 實例。

注意

可轉換為張量的類型(例如 int、float 或 np.ndarray)將轉換為 torch.Tensor 實例。 請注意,此轉換是滿射的:將 tensordict 轉換回 pytree 不會恢復原始類型。

範例

>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key
>>> class WeirdLookingClass:
...     pass
...
>>> weird_key = WeirdLookingClass()
>>> # Make a pytree with tuple, lists, dict and namedtuple
>>> pytree = (
...     [torch.randint(10, (3,)), torch.zeros(2)],
...     {
...         "tensor": torch.randn(
...             2,
...         ),
...         "td": TensorDict({"one": 1}),
...         weird_key: torch.randint(10, (2,)),
...         "list": [1, 2, 3],
...     },
...     {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
... )
>>> # Build a TensorDict from that pytree
>>> td = from_pytree(pytree)
>>> # Recover the pytree
>>> pytree_recon = td.to_pytree()
>>> # Check that the leaves match
>>> def check(v1, v2):
>>>     assert (v1 == v2).all()
>>>
>>> torch.utils._pytree.tree_map(check, pytree, pytree_recon)
>>> assert weird_key in pytree_recon[1]

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得適合初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得解答

檢視資源