快捷方式

load_memmap

class tensordict.load_memmap(prefix: str | pathlib.Path, device: Optional[device] = None, non_blocking: bool = False, *, out: Optional[TensorDictBase] = None)

從磁碟載入記憶體映射的 tensordict。

參數:
  • prefix (str資料夾路徑) – 儲存的 tensordict 應該從中提取的資料夾路徑。

  • device (torch.device等效, 可選) – 如果提供,資料將會非同步地轉換到該裝置。 支援 “meta” 裝置,在這種情況下,資料不會被載入,而是會建立一組空的 “meta” 張量。 這對於了解總模型大小和結構,而無需實際打開任何檔案非常有用。

  • non_blocking (bool, 可選) – 如果 True,在將張量載入到裝置後,將不會呼叫 synchronize。 預設為 False

  • out (TensorDictBase, 可選) – 可選的 tensordict,資料應寫入其中。

範例

>>> from tensordict import TensorDict, load_memmap
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

此方法也允許載入巢狀的 tensordicts。

範例

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0

tensordict 也可以載入到 “meta” 裝置上,或者作為一個假的張量。

範例

>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學文件

取得初學者和高級開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源