• 文件 >
  • 使用 TensorDict 預先分配記憶體
快捷方式

使用 TensorDict 預先分配記憶體

作者: Tom Begley

在本教學中,您將學習如何利用 TensorDict 中的記憶體預先分配。

假設我們有一個函式會回傳一個 TensorDict

import torch
from tensordict.tensordict import TensorDict


def make_tensordict():
    return TensorDict({"a": torch.rand(3), "b": torch.rand(3, 4)}, [3])

也許我們想多次呼叫此函式,並使用結果來填充單一個 TensorDict

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])

for i in range(N):
    tensordict[i] = make_tensordict()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

因為我們指定了 tensordictbatch_size,所以在迴圈的第一次迭代中,我們用空的 tensor 填充 tensordict,這些 tensor 的第一個維度的大小為 N,其餘維度則由 make_tensordict 的回傳值決定。在上面的範例中,我們為鍵 "a" 預先分配一個大小為 torch.Size([10, 3]) 的零陣列,以及為鍵 "b" 預先分配一個大小為 torch.Size([10, 3, 4]) 的陣列。迴圈的後續迭代會就地寫入。因此,如果不是所有的值都被填滿,它們會得到預設值零。

讓我們透過逐步執行上面的迴圈來示範正在發生的事情。 我們首先初始化一個空的 TensorDict

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])
print(tensordict)
TensorDict(
    fields={
    },
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

在第一次迭代之後,tensordict 已經預先填充了鍵 "a""b" 的 tensor。 這些 tensor 包含零,除了我們已將隨機值分配給第一列之外。

random_tensordict = make_tensordict()
tensordict[0] = random_tensordict

assert (tensordict[1:] == 0).all()
assert (tensordict[0] == random_tensordict).all()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

在後續的迭代中,我們會就地更新預先分配的 tensor。

a = tensordict["a"]
random_tensordict = make_tensordict()
tensordict[1] = random_tensordict

# the same tensor is stored under "a", but the values have been updated
assert tensordict["a"] is a
assert (tensordict[:2] != 0).all()

腳本的總執行時間: (0 分鐘 0.003 秒)

由 Sphinx-Gallery 產生的圖片集

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得初學者和高級開發者的深入教學課程

查看教學

資源

尋找開發資源並獲得問題解答

查看資源