• 文件 >
  • 使用 TensorDict 處理資料集
快捷方式

使用 TensorDict 處理資料集

在本教學中,我們將示範如何使用 TensorDict 來有效率且透明地載入和管理訓練流程中的資料。本教學主要基於 PyTorch 快速入門教學,但修改為示範如何使用 TensorDict

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, TensorDict
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu

torchvision.datasets 模組包含許多方便的預先準備的資料集。在本教學中,我們將使用相對簡單的 FashionMNIST 資料集。每張圖片都是一件衣服,目標是將圖片中的衣服類型進行分類(例如「包包」、「運動鞋」等)。

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

我們將建立兩個 tensordict,分別用於訓練和測試資料。我們建立記憶體映射張量來保存資料。這將允許我們有效率地從磁碟載入轉換後的資料批次,而不是重複載入和轉換個別圖片。

首先,我們建立 MemoryMappedTensor 容器。

training_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(training_data), *training_data[0][0].squeeze().shape),
            dtype=torch.float32,
        ),
        "targets": MemoryMappedTensor.empty((len(training_data),), dtype=torch.int64),
    },
    batch_size=[len(training_data)],
    device=device,
)
test_data_td = TensorDict(
    {
        "images": MemoryMappedTensor.empty(
            (len(test_data), *test_data[0][0].squeeze().shape), dtype=torch.float32
        ),
        "targets": MemoryMappedTensor.empty((len(test_data),), dtype=torch.int64),
    },
    batch_size=[len(test_data)],
    device=device,
)

然後,我們可以迭代資料來填充記憶體映射張量。這需要一些時間,但預先執行轉換將節省稍後訓練期間的重複工作。

for i, (img, label) in enumerate(training_data):
    training_data_td[i] = TensorDict({"images": img, "targets": label}, [])

for i, (img, label) in enumerate(test_data):
    test_data_td[i] = TensorDict({"images": img, "targets": label}, [])

DataLoaders

我們將從 torchvision 提供的資料集以及我們的記憶體映射 TensorDict 建立 DataLoaders。

由於 TensorDict 實作了 __len____getitem__(以及 __getitems__),我們可以像使用 map-style Dataset 一樣使用它,並直接從它建立 DataLoader。請注意,由於 TensorDict 已經可以處理批次索引,因此不需要整理,因此我們將恆等函數作為 collate_fn 傳遞。

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

train_dataloader_td = DataLoader(  # noqa: TOR401
    training_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_td = DataLoader(  # noqa: TOR401
    test_data_td, batch_size=batch_size, collate_fn=lambda x: x
)

模型

我們使用與 快速入門教學中相同的模型。

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = Net().to(device)
model_td = Net().to(device)
model, model_td
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))

最佳化參數

我們將使用隨機梯度下降和交叉熵損失來最佳化模型的參數。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_td = torch.optim.SGD(model_td.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

我們基於 TensorDict 的 DataLoader 的訓練迴圈非常相似,我們只是調整如何解包資料,以適應 TensorDict 提供的更明確的基於鍵的檢索。.contiguous() 方法載入儲存在 memmap 張量中的資料。

def train_td(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, data in enumerate(dataloader):
        X, y = data["images"].contiguous(), data["targets"].contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


def test_td(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch["images"].contiguous(), batch["targets"].contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


for d in train_dataloader_td:
    print(d)
    break

import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
    test_td(test_dataloader_td, model_td, loss_fn)
print(f"TensorDict training done! time: {time.time() - t0: 4.4f} s")

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
TensorDict(
    fields={
        images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
        targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.293629 [    0/60000]
loss: 2.283033 [ 6400/60000]
loss: 2.256484 [12800/60000]
loss: 2.256073 [19200/60000]
loss: 2.243077 [25600/60000]
loss: 2.200774 [32000/60000]
loss: 2.217680 [38400/60000]
loss: 2.179253 [44800/60000]
loss: 2.169554 [51200/60000]
loss: 2.139296 [57600/60000]
Test Error:
 Accuracy: 44.9%, Avg loss: 2.136127

Epoch 2
-------------------------
loss: 2.144999 [    0/60000]
loss: 2.140007 [ 6400/60000]
loss: 2.071116 [12800/60000]
loss: 2.092876 [19200/60000]
loss: 2.046267 [25600/60000]
loss: 1.972650 [32000/60000]
loss: 2.009594 [38400/60000]
loss: 1.925318 [44800/60000]
loss: 1.919691 [51200/60000]
loss: 1.849112 [57600/60000]
Test Error:
 Accuracy: 59.9%, Avg loss: 1.853141

Epoch 3
-------------------------
loss: 1.884452 [    0/60000]
loss: 1.857988 [ 6400/60000]
loss: 1.732184 [12800/60000]
loss: 1.780074 [19200/60000]
loss: 1.670048 [25600/60000]
loss: 1.618595 [32000/60000]
loss: 1.644744 [38400/60000]
loss: 1.546432 [44800/60000]
loss: 1.562902 [51200/60000]
loss: 1.465493 [57600/60000]
Test Error:
 Accuracy: 61.8%, Avg loss: 1.488911

Epoch 4
-------------------------
loss: 1.553364 [    0/60000]
loss: 1.522695 [ 6400/60000]
loss: 1.372123 [12800/60000]
loss: 1.449625 [19200/60000]
loss: 1.332421 [25600/60000]
loss: 1.329863 [32000/60000]
loss: 1.347445 [38400/60000]
loss: 1.271675 [44800/60000]
loss: 1.303136 [51200/60000]
loss: 1.211887 [57600/60000]
Test Error:
 Accuracy: 63.5%, Avg loss: 1.239002

Epoch 5
-------------------------
loss: 1.314920 [    0/60000]
loss: 1.296802 [ 6400/60000]
loss: 1.132275 [12800/60000]
loss: 1.240491 [19200/60000]
loss: 1.118523 [25600/60000]
loss: 1.144619 [32000/60000]
loss: 1.168913 [38400/60000]
loss: 1.102033 [44800/60000]
loss: 1.141350 [51200/60000]
loss: 1.064545 [57600/60000]
Test Error:
 Accuracy: 64.8%, Avg loss: 1.083915

TensorDict training done! time:  8.6170 s
Epoch 1
-------------------------
loss: 2.290858 [    0/60000]
loss: 2.286275 [ 6400/60000]
loss: 2.267298 [12800/60000]
loss: 2.265695 [19200/60000]
loss: 2.254766 [25600/60000]
loss: 2.215138 [32000/60000]
loss: 2.229925 [38400/60000]
loss: 2.191596 [44800/60000]
loss: 2.188072 [51200/60000]
loss: 2.166435 [57600/60000]
Test Error:
 Accuracy: 43.9%, Avg loss: 2.155688

Epoch 2
-------------------------
loss: 2.158259 [    0/60000]
loss: 2.155049 [ 6400/60000]
loss: 2.093325 [12800/60000]
loss: 2.112539 [19200/60000]
loss: 2.072581 [25600/60000]
loss: 1.997474 [32000/60000]
loss: 2.036950 [38400/60000]
loss: 1.950100 [44800/60000]
loss: 1.957567 [51200/60000]
loss: 1.895268 [57600/60000]
Test Error:
 Accuracy: 55.7%, Avg loss: 1.885085

Epoch 3
-------------------------
loss: 1.911978 [    0/60000]
loss: 1.891730 [ 6400/60000]
loss: 1.763546 [12800/60000]
loss: 1.807613 [19200/60000]
loss: 1.707991 [25600/60000]
loss: 1.642813 [32000/60000]
loss: 1.683022 [38400/60000]
loss: 1.573177 [44800/60000]
loss: 1.606999 [51200/60000]
loss: 1.511539 [57600/60000]
Test Error:
 Accuracy: 61.3%, Avg loss: 1.516370

Epoch 4
-------------------------
loss: 1.578484 [    0/60000]
loss: 1.552396 [ 6400/60000]
loss: 1.389732 [12800/60000]
loss: 1.469645 [19200/60000]
loss: 1.355540 [25600/60000]
loss: 1.340846 [32000/60000]
loss: 1.369824 [38400/60000]
loss: 1.286581 [44800/60000]
loss: 1.330394 [51200/60000]
loss: 1.239374 [57600/60000]
Test Error:
 Accuracy: 63.4%, Avg loss: 1.253140

Epoch 5
-------------------------
loss: 1.325610 [    0/60000]
loss: 1.314812 [ 6400/60000]
loss: 1.139354 [12800/60000]
loss: 1.251529 [19200/60000]
loss: 1.128213 [25600/60000]
loss: 1.147657 [32000/60000]
loss: 1.176389 [38400/60000]
loss: 1.109371 [44800/60000]
loss: 1.156932 [51200/60000]
loss: 1.078585 [57600/60000]
Test Error:
 Accuracy: 64.4%, Avg loss: 1.089056

Training done! time:  34.5017 s

腳本的總運行時間:(0 分鐘 55.621 秒)

由 Sphinx-Gallery 產生

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適用於初學者和進階開發者的深度教學

檢視教學

資源

尋找開發資源並取得解答

檢視資源