• 文件 >
  • 使用 TensorDict 簡化 PyTorch 記憶體管理
捷徑

使用 TensorDict 簡化 PyTorch 記憶體管理

作者: Tom Begley

在本教學中,您將學習如何控制 TensorDict 的內容在記憶體中的儲存位置,可以將這些內容傳送到裝置,也可以使用記憶體映射。

裝置

當您建立 TensorDict 時,可以使用 device 關鍵字引數指定裝置。如果設定了 device,則 TensorDict 的所有條目都會放置在該裝置上。如果未設定 device,則 TensorDict 中的條目不一定要位於同一裝置上。

在這個範例中,我們使用 device="cuda:0" 實例化一個 TensorDict。當我們列印內容時,可以看到它們已移動到裝置上。

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"a": torch.rand(10)}, [10], device="cuda:0")
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

如果 TensorDict 的裝置不是 None,則新的條目也會移動到裝置上。

>>> tensordict["b"] = torch.rand(10, 10)
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

您可以使用 device 屬性檢查 TensorDict 的目前裝置。

>>> print(tensordict.device)
cuda:0

可以像 PyTorch 張量一樣,使用 TensorDict.cuda()TensorDict.device(device),將 TensorDict 的內容傳送到裝置,其中 device 是所需的裝置。

>>> tensordict.to(torch.device("cpu"))
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
>>> tensordict.cuda()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

TensorDict.device 方法需要傳遞有效的裝置作為引數。如果您想從 TensorDict 中移除裝置,以允許使用不同裝置的值,則應使用 TensorDict.clear_device 方法。

>>> tensordict.clear_device()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

記憶體映射張量

tensordict 提供了一個類別 MemoryMappedTensor,允許我們將張量的內容儲存在磁碟上,同時仍然支援快速索引和批次載入內容。請參閱 ImageNet 教學以取得實際範例。

若要將 TensorDict 轉換為記憶體映射張量的集合,請使用 TensorDict.memmap_

tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
tensordict.memmap_()

print(tensordict)
TensorDict(
    fields={
        a: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)

或者,您可以使用 TensorDict.memmap_like 方法。這將創建一個具有相同結構的新 TensorDict,其值為 MemoryMappedTensor,但不會將原始張量的內容複製到記憶體對應的張量中。這允許您創建記憶體對應的 TensorDict,然後慢慢填充它,因此通常應優先於 memmap_

tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
mm_tensordict = tensordict.memmap_like()

print(mm_tensordict["a"].contiguous())
MemoryMappedTensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

預設情況下,TensorDict 的內容將儲存到磁碟上的臨時位置,但是,如果您想控制儲存位置,可以使用關鍵字參數 prefix="/path/to/root"

TensorDict 的內容儲存在一個目錄結構中,該結構模仿了 TensorDict 本身的結構。張量的內容儲存在 NumPy memmap 中,元資料儲存在相關聯的 PyTorch 儲存檔案中。例如,上述 TensorDict 的儲存方式如下

├── a.memmap
├── a.meta.pt
├── b
│ ├── c.memmap
│ ├── c.meta.pt
│ └── meta.pt
└── meta.pt

腳本的總執行時間: (0 分鐘 0.004 秒)

由 Sphinx-Gallery 產生

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深度教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源