• 文件 >
  • 使用 tensorclass 處理資料集
快捷方式

使用 tensorclass 處理資料集

在本教學中,我們將示範如何使用 tensorclass 有效率且透明地載入和管理訓練流程中的資料。 本教學大量基於 PyTorch 快速入門教學,但修改為示範 tensorclass 的使用。 請參閱使用 TensorDict 的相關教學課程。

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu

torchvision.datasets 模組包含許多方便的預先準備好的資料集。 在本教學中,我們將使用相對簡單的 FashionMNIST 資料集。 每張圖片都是一件衣服,目標是對圖片中的衣服類型進行分類 (例如「包包」、「運動鞋」等)。

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
  0%|          | 0.00/26.4M [00:00<?, ?B/s]
  0%|          | 65.5k/26.4M [00:00<01:13, 361kB/s]
  1%|          | 229k/26.4M [00:00<00:38, 680kB/s]
  3%|▎         | 885k/26.4M [00:00<00:10, 2.53MB/s]
  7%|▋         | 1.93M/26.4M [00:00<00:05, 4.09MB/s]
 25%|██▍       | 6.59M/26.4M [00:00<00:01, 15.4MB/s]
 37%|███▋      | 9.70M/26.4M [00:00<00:01, 16.5MB/s]
 60%|██████    | 15.9M/26.4M [00:01<00:00, 23.1MB/s]
 84%|████████▍ | 22.2M/26.4M [00:01<00:00, 27.0MB/s]
100%|██████████| 26.4M/26.4M [00:01<00:00, 19.3MB/s]

  0%|          | 0.00/29.5k [00:00<?, ?B/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 324kB/s]

  0%|          | 0.00/4.42M [00:00<?, ?B/s]
  1%|▏         | 65.5k/4.42M [00:00<00:12, 359kB/s]
  5%|▌         | 229k/4.42M [00:00<00:06, 675kB/s]
 21%|██        | 918k/4.42M [00:00<00:01, 2.09MB/s]
 83%|████████▎ | 3.67M/4.42M [00:00<00:00, 7.21MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 6.02MB/s]

  0%|          | 0.00/5.15k [00:00<?, ?B/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 69.9MB/s]

Tensorclass 是資料類別,它透過其內容公開專用的 tensor 方法,非常類似於 TensorDict。 當您想要儲存的資料結構是固定且可預測時,它們是一個不錯的選擇。

除了指定內容外,我們還可以在定義類別時將相關邏輯封裝為自訂方法。 在這種情況下,我們將編寫一個 from_dataset 類別方法,該方法採用資料集作為輸入,並建立一個包含資料集資料的 tensorclass。 我們建立記憶體對應 tensor 以保存資料。 這將使我們能夠有效地從磁碟載入批次的轉換資料,而不是重複載入和轉換個別圖像。

@tensorclass
class FashionMNISTData:
    images: torch.Tensor
    targets: torch.Tensor

    @classmethod
    def from_dataset(cls, dataset, device=None):
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
            device=device,
        )
        for i, (image, target) in enumerate(dataset):
            data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
        return data

我們將建立兩個 tensorclass,每個分別用於訓練和測試資料。 請注意,由於我們正在遍歷整個資料集、轉換並儲存到磁碟,因此這裡會產生一些額外負擔。

training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)

DataLoaders

我們將從 torchvision 提供的資料集以及我們的記憶體對應 TensorDict 建立 DataLoaders。

由於 TensorDict 實現了 __len____getitem__(以及 __getitems__),我們可以像 map-style 資料集一樣使用它,並直接從中建立 DataLoader。 請注意,由於 TensorDict 已經可以處理批次索引,因此不需要排序規則,因此我們將恆等函數作為 collate_fn 傳遞。

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

train_dataloader_tc = DataLoader(  # noqa: TOR401
    training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader(  # noqa: TOR401
    test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)

模型

我們使用與 快速入門教學中相同的模型。

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))

優化參數

我們將使用隨機梯度下降和交叉熵損失來優化模型的參數。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

我們基於 tensorclass 的 DataLoader 的訓練迴圈非常相似,我們只是調整如何將資料解壓縮到 tensorclass 提供的更明確的基於屬性的檢索。 .contiguous() 方法載入儲存在 memmap tensor 中的資料。

def train_tc(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, data in enumerate(dataloader):
        X, y = data.images.contiguous(), data.targets.contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


def test_tc(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch.images.contiguous(), batch.targets.contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


for d in train_dataloader_tc:
    print(d)
    break

import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
    test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
    images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.306911 [    0/60000]
loss: 2.290546 [ 6400/60000]
loss: 2.267823 [12800/60000]
loss: 2.268296 [19200/60000]
loss: 2.239312 [25600/60000]
loss: 2.222285 [32000/60000]
loss: 2.223311 [38400/60000]
loss: 2.189771 [44800/60000]
loss: 2.196405 [51200/60000]
loss: 2.161141 [57600/60000]
Test Error:
 Accuracy: 48.7%, Avg loss: 2.154104

Epoch 2
-------------------------
loss: 2.168520 [    0/60000]
loss: 2.153506 [ 6400/60000]
loss: 2.092801 [12800/60000]
loss: 2.109224 [19200/60000]
loss: 2.055430 [25600/60000]
loss: 2.006157 [32000/60000]
loss: 2.025786 [38400/60000]
loss: 1.945255 [44800/60000]
loss: 1.956291 [51200/60000]
loss: 1.879478 [57600/60000]
Test Error:
 Accuracy: 59.4%, Avg loss: 1.879949

Epoch 3
-------------------------
loss: 1.919434 [    0/60000]
loss: 1.884455 [ 6400/60000]
loss: 1.763246 [12800/60000]
loss: 1.796725 [19200/60000]
loss: 1.699208 [25600/60000]
loss: 1.656230 [32000/60000]
loss: 1.662529 [38400/60000]
loss: 1.567700 [44800/60000]
loss: 1.596351 [51200/60000]
loss: 1.477721 [57600/60000]
Test Error:
 Accuracy: 61.8%, Avg loss: 1.509531

Epoch 4
-------------------------
loss: 1.582884 [    0/60000]
loss: 1.547020 [ 6400/60000]
loss: 1.392604 [12800/60000]
loss: 1.457310 [19200/60000]
loss: 1.349238 [25600/60000]
loss: 1.348273 [32000/60000]
loss: 1.347054 [38400/60000]
loss: 1.278697 [44800/60000]
loss: 1.319937 [51200/60000]
loss: 1.205683 [57600/60000]
Test Error:
 Accuracy: 63.1%, Avg loss: 1.246076

Epoch 5
-------------------------
loss: 1.329433 [    0/60000]
loss: 1.309014 [ 6400/60000]
loss: 1.141340 [12800/60000]
loss: 1.241332 [19200/60000]
loss: 1.118123 [25600/60000]
loss: 1.152766 [32000/60000]
loss: 1.158768 [38400/60000]
loss: 1.103974 [44800/60000]
loss: 1.149251 [51200/60000]
loss: 1.051577 [57600/60000]
Test Error:
 Accuracy: 64.4%, Avg loss: 1.084601

Tensorclass training done! time:  8.6816 s
Epoch 1
-------------------------
loss: 2.307089 [    0/60000]
loss: 2.297157 [ 6400/60000]
loss: 2.284588 [12800/60000]
loss: 2.276901 [19200/60000]
loss: 2.242071 [25600/60000]
loss: 2.219066 [32000/60000]
loss: 2.220341 [38400/60000]
loss: 2.188907 [44800/60000]
loss: 2.183793 [51200/60000]
loss: 2.147927 [57600/60000]
Test Error:
 Accuracy: 43.2%, Avg loss: 2.146838

Epoch 2
-------------------------
loss: 2.158795 [    0/60000]
loss: 2.147090 [ 6400/60000]
loss: 2.094091 [12800/60000]
loss: 2.108519 [19200/60000]
loss: 2.044007 [25600/60000]
loss: 1.981364 [32000/60000]
loss: 2.003546 [38400/60000]
loss: 1.929537 [44800/60000]
loss: 1.926183 [51200/60000]
loss: 1.847433 [57600/60000]
Test Error:
 Accuracy: 57.0%, Avg loss: 1.857306

Epoch 3
-------------------------
loss: 1.891647 [    0/60000]
loss: 1.856303 [ 6400/60000]
loss: 1.746082 [12800/60000]
loss: 1.786176 [19200/60000]
loss: 1.664370 [25600/60000]
loss: 1.619249 [32000/60000]
loss: 1.633839 [38400/60000]
loss: 1.550666 [44800/60000]
loss: 1.565413 [51200/60000]
loss: 1.457971 [57600/60000]
Test Error:
 Accuracy: 62.1%, Avg loss: 1.485636

Epoch 4
-------------------------
loss: 1.555910 [    0/60000]
loss: 1.517564 [ 6400/60000]
loss: 1.375235 [12800/60000]
loss: 1.446436 [19200/60000]
loss: 1.323968 [25600/60000]
loss: 1.320687 [32000/60000]
loss: 1.333024 [38400/60000]
loss: 1.270114 [44800/60000]
loss: 1.298299 [51200/60000]
loss: 1.202938 [57600/60000]
Test Error:
 Accuracy: 63.9%, Avg loss: 1.229178

Epoch 5
-------------------------
loss: 1.310207 [    0/60000]
loss: 1.290053 [ 6400/60000]
loss: 1.126103 [12800/60000]
loss: 1.231878 [19200/60000]
loss: 1.110180 [25600/60000]
loss: 1.128504 [32000/60000]
loss: 1.154162 [38400/60000]
loss: 1.097552 [44800/60000]
loss: 1.130474 [51200/60000]
loss: 1.054337 [57600/60000]
Test Error:
 Accuracy: 65.1%, Avg loss: 1.072044

Training done! time:  34.6215 s

腳本的總運行時間:(1 分鐘 1.114 秒)

由 Sphinx-Gallery 產生圖庫

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得適合初學者和進階開發人員的深度教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源