快捷方式

from_dict

class tensordict.from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None)

傳回從字典或另一個 TensorDict 建立的 TensorDict。

如果未指定 batch_size,則傳回可能的最大批次大小。

此函數也適用於巢狀字典,或可用於確定巢狀 tensordict 的批次大小。

參數:
  • input_dict (dictionary, optional) – 要用作資料來源的字典 (相容巢狀鍵)。

  • batch_size (iterable of int, optional) – tensordict 的批次大小。

  • device (torch.device or compatible type, optional) – TensorDict 的裝置。

  • batch_dims (int, optional) – batch_dims (即要考慮用於 batch_size 的前導維度的數量)。與 batch_size 互斥。請注意,這是 tensordict 的 __maximum__ 批次維度數量,可以容忍較小的數字。

  • names (list of str, optional) – tensordict 的維度名稱。

範例

>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(from_dict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(from_dict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # we can also use this to work out the batch sie of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(
from_dict(input_td))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得解答

檢視資源