from_pytree¶
- class tensordict.from_pytree(pytree, *, batch_size: Optional[Size] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None)¶
將 pytree 轉換為 TensorDict 實例。
此方法旨在盡可能保留 pytree 的巢狀結構。
新增額外的非張量鍵以追蹤每個層級的身分,提供內建的 pytree-to-tensordict 雙射轉換 API。
目前接受的類別包括 lists、tuples、named tuples 和 dict。
注意
對於字典,非 NestedKey 鍵會單獨註冊為
NonTensorData
實例。注意
可轉換為張量的類型(例如 int、float 或 np.ndarray)將轉換為 torch.Tensor 實例。 請注意,此轉換是滿射的:將 tensordict 轉換回 pytree 不會恢復原始類型。
範例
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key >>> class WeirdLookingClass: ... pass ... >>> weird_key = WeirdLookingClass() >>> # Make a pytree with tuple, lists, dict and namedtuple >>> pytree = ( ... [torch.randint(10, (3,)), torch.zeros(2)], ... { ... "tensor": torch.randn( ... 2, ... ), ... "td": TensorDict({"one": 1}), ... weird_key: torch.randint(10, (2,)), ... "list": [1, 2, 3], ... }, ... {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, ... ) >>> # Build a TensorDict from that pytree >>> td = from_pytree(pytree) >>> # Recover the pytree >>> pytree_recon = td.to_pytree() >>> # Check that the leaves match >>> def check(v1, v2): >>> assert (v1 == v2).all() >>> >>> torch.utils._pytree.tree_map(check, pytree, pytree_recon) >>> assert weird_key in pytree_recon[1]