from_dict
- class tensordict.from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None)
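Returns a TensorDict created from a dictionary or another TensorDict. If batch_size is not specified, the largest possible batch size is returned. This function also works on nested dictionaries, and it can be used to determine the batch size of a nested tensordict.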
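The "largest possible batch size" rule can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: `infer_batch_size` and `_common_prefix` are hypothetical helpers, leaves are assumed to be either objects with a `.shape` attribute (such as torch tensors) or plain shape tuples, and nested plain `dict`s stand in for nested tensordicts. Each nested dict first gets its own batch size, and the parent's batch size is the longest run of leading dimensions shared by all of its entries.

```python
from typing import Any, Dict, Iterable, Tuple

Shape = Tuple[int, ...]


def _common_prefix(shapes: Iterable[Shape]) -> Shape:
    """Hypothetical helper: longest run of leading dims shared by every shape."""
    prefix = []
    # zip stops at the shortest shape, so the prefix can never exceed it.
    for dims in zip(*shapes):
        if all(d == dims[0] for d in dims):
            prefix.append(dims[0])
        else:
            break
    return tuple(prefix)


def infer_batch_size(tree: Dict[str, Any]) -> Shape:
    """Hypothetical helper: largest batch size compatible with every entry.

    Assumption: leaves are torch-like objects exposing `.shape`, or plain
    shape tuples; nested dicts play the role of nested tensordicts.
    """
    child_shapes = []
    for value in tree.values():
        if isinstance(value, dict):
            # A nested dict gets its own batch size first; the parent only
            # has to agree with that size, not with the nested leaves.
            child_shapes.append(infer_batch_size(value))
        else:
            child_shapes.append(tuple(getattr(value, "shape", value)))
    return _common_prefix(child_shapes)
```

With the shapes from the nested example, `{"a": (3,), "b": {"c": (3, 4)}}`, the nested entry resolves to `(3, 4)` and the parent to `(3,)`, which is why the two levels can report different batch sizes.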
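- Parameters:
input_dict (dictionary, optional) – the dictionary to use as the data source (nested keys are supported).
batch_size (iterable of int, optional) – the batch size of the tensordict.
device (torch.device or compatible type, optional) – the device of the TensorDict.
batch_dims (int, optional) – the batch_dims (i.e. the number of leading dimensions to be considered for batch_size). Mutually exclusive with batch_size. Note that this is the maximum number of batch dimensions of the tensordict; a smaller number is tolerated.
names (list of str, optional) – the dimension names of the tensordict.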
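The interplay between batch_size and batch_dims can be sketched as a small resolution step applied to an already-inferred batch size. `resolve_batch_size` is a hypothetical helper written only to mirror the documented rules: the two arguments are mutually exclusive, and batch_dims is an upper bound rather than an exact count. A real TensorDict also validates an explicit batch size against the tensor shapes; that check is deliberately omitted here.

```python
from typing import Optional, Sequence, Tuple


def resolve_batch_size(
    inferred: Tuple[int, ...],
    batch_size: Optional[Sequence[int]] = None,
    batch_dims: Optional[int] = None,
) -> Tuple[int, ...]:
    """Hypothetical helper mirroring the documented batch_size/batch_dims rules."""
    if batch_size is not None and batch_dims is not None:
        # The two arguments are documented as mutually exclusive.
        raise ValueError("batch_size and batch_dims are mutually exclusive")
    if batch_size is not None:
        # Explicit batch size is returned as given (no shape validation here).
        return tuple(batch_size)
    if batch_dims is not None:
        # batch_dims is a maximum: slicing keeps at most batch_dims leading
        # dims, and silently keeps fewer if fewer were inferred.
        return tuple(inferred[:batch_dims])
    # Neither argument given: keep the largest possible batch size.
    return tuple(inferred)
```

For example, an inferred size of `(3, 4)` becomes `(3,)` with `batch_dims=1`, while an inferred size of `(3,)` stays `(3,)` even with `batch_dims=2`.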
Examples
>>> import torch
>>> from tensordict import TensorDict, from_dict
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(from_dict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(from_dict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # we can also use this to work out the batch size of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(from_dict(input_td))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)