from_namedtuple
- class tensordict.from_namedtuple(named_tuple, *, auto_batch_size: bool = False)
Recursively converts a namedtuple to a TensorDict.
- Keyword Arguments:
auto_batch_size (bool, optional) – if True, the batch size is computed automatically. Defaults to False.
Examples
>>> from tensordict import TensorDict, from_namedtuple
>>> import torch
>>> data = TensorDict({
...     "a_tensor": torch.zeros((3)),
...     "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3])
>>> nt = data.to_namedtuple()
>>> print(nt)
GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!'))
>>> from_namedtuple(nt, auto_batch_size=True)
TensorDict(
    fields={
        a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None),
                a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
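
To make "recursively" concrete, the sketch below shows how nested namedtuples can be flattened into nested dictionaries, which is the kind of structure a TensorDict is built from. It is an illustration only, not tensordict's actual implementation: `namedtuple_to_dict`, `Inner`, and `Outer` are hypothetical names, and plain lists stand in for tensors so the snippet runs with the standard library alone.

```python
from collections import namedtuple


def namedtuple_to_dict(obj):
    """Hypothetical helper: recursively turn nested namedtuples into nested dicts."""
    # A namedtuple is a tuple subclass that exposes _fields and _asdict().
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return {key: namedtuple_to_dict(value) for key, value in obj._asdict().items()}
    # Leaves (tensors, strings, numbers, ...) are returned unchanged.
    return obj


# Hypothetical namedtuple types mirroring the layout used in the example above.
Inner = namedtuple("Inner", ["a_tensor", "a_string"])
Outer = namedtuple("Outer", ["a_tensor", "nested"])

# Plain lists stand in for tensors (an assumption made to avoid third-party imports).
nt = Outer(
    a_tensor=[0.0, 0.0, 0.0],
    nested=Inner(a_tensor=[0.0, 0.0, 0.0], a_string="zero!"),
)

result = namedtuple_to_dict(nt)
print(result)
```

Each namedtuple level becomes one dictionary level, which matches the nesting visible in the `from_namedtuple` output above, where `nested` appears as an inner TensorDict.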