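Using tensorclasses for datasets
In this tutorial we demonstrate how tensorclass can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart Tutorial, but modified to demonstrate the use of tensorclass. See the related tutorial using TensorDict.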
import torch
import torch.nn as nn
from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu
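The torchvision.datasets module contains a number of convenient, pre-prepared datasets. In this tutorial we use the relatively simple FashionMNIST dataset. Each image shows an item of clothing, and the objective is to classify the type of clothing in the image (e.g. "Bag", "Sneaker", etc.).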
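For reference, the targets are integers from 0 to 9. The lookup sketch below assumes the standard FashionMNIST class order, which torchvision also exposes as `training_data.classes`. `label_name` is a hypothetical helper.

```python
# Standard FashionMNIST class names, indexed by integer label (0-9).
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_name(target: int) -> str:
    """Map an integer target (as returned by the dataset) to its class name."""
    return FASHION_MNIST_CLASSES[target]


print(label_name(8))  # Bag
print(label_name(7))  # Sneaker
```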
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
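Tensorclasses are dataclasses that, much like TensorDict, expose dedicated tensor methods that operate over their contents. They are a good choice when the structure of the data you want to store is fixed and predictable.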
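Here is a toy analogue in plain Python that does not depend on tensordict. It is a dataclass that applies an index to every field at once, which is roughly what a tensorclass does for its tensor fields. `ToyBatch` is hypothetical, handles slices only, and does not reflect how tensorclass is actually implemented.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class ToyBatch:
    """Toy stand-in for a tensorclass: plain lists instead of tensors."""
    images: List[list]
    targets: List[int]

    def __len__(self) -> int:
        # Leading ("batch") dimension shared by every field.
        return len(self.targets)

    def __getitem__(self, index: slice) -> "ToyBatch":
        # Apply the same slice to every field and rebuild the same class.
        return type(self)(**{f.name: getattr(self, f.name)[index] for f in fields(self)})


batch = ToyBatch(images=[[0.0], [0.1], [0.2], [0.3]], targets=[3, 1, 4, 1])
sub = batch[1:3]
print(len(sub), sub.targets)  # 2 [1, 4]
```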
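As well as specifying the contents, we can encapsulate related logic as custom methods when defining the class. In this case we write a from_dataset classmethod that takes a dataset as input and creates a tensorclass holding the dataset's data. We create memory-mapped tensors to hold the data. This lets us efficiently load batches of already-transformed data from disk, rather than repeatedly loading and transforming individual images.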
@tensorclass
class FashionMNISTData:
    images: torch.Tensor
    targets: torch.Tensor
    @classmethod
    def from_dataset(cls, dataset, device=None):
        # Preallocate memory-mapped storage for the whole dataset (images are squeezed to 28x28).
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
            device=device,
        )
        # One-time pass: transform each sample once and write it into its slot.
        for i, (image, target) in enumerate(dataset):
            data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
        return data
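The preallocate-then-fill pattern in from_dataset can be sketched with the standard library alone. This hypothetical analogue assumes float32 rows stored back to back in a temporary file accessed through `mmap`. `build_memmap` and `read_batch` are illustrative names, not tensordict APIs. Once every row has been written, a batch of consecutive rows comes back from a single contiguous read, with no per-item loading or transforming.

```python
import mmap
import struct
import tempfile

FLOAT = struct.calcsize("f")  # 4 bytes, i.e. float32


def build_memmap(items, item_len):
    """Preallocate a file-backed buffer and write each item into its slot once."""
    nbytes = len(items) * item_len * FLOAT
    backing = tempfile.TemporaryFile()
    backing.truncate(nbytes)  # reserve the full size up front (cf. MemoryMappedTensor.empty)
    buf = mmap.mmap(backing.fileno(), nbytes)
    for i, item in enumerate(items):
        # One-time "transform and store" pass, analogous to data[i] = cls(...)
        struct.pack_into(f"{item_len}f", buf, i * item_len * FLOAT, *item)
    return backing, buf  # keep the file object alive while buf is in use


def read_batch(buf, item_len, start, stop):
    """Read rows [start, stop) with a single contiguous unpack."""
    count = (stop - start) * item_len
    flat = struct.unpack_from(f"{count}f", buf, start * item_len * FLOAT)
    return [list(flat[j * item_len:(j + 1) * item_len]) for j in range(stop - start)]


backing, buf = build_memmap([[0.0, 0.5], [1.0, 1.5], [2.0, 2.5]], item_len=2)
print(read_batch(buf, 2, 1, 3))  # [[1.0, 1.5], [2.0, 2.5]]
```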
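We will create two tensorclasses, one for the training data and one for the test data. Note that this incurs some one-time overhead, since we loop over the entire dataset, transforming each sample and saving it to disk.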
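The on-disk cost of this one-time conversion follows directly from the dtypes in from_dataset: float32 images and int64 targets. The back-of-the-envelope sketch below assumes 28×28 images, the 60,000 training samples shown in the logs, and the standard 10,000-image FashionMNIST test split. `memmap_bytes` is a hypothetical helper.

```python
# Bytes per element for the dtypes used in from_dataset.
FLOAT32_BYTES = 4  # images: torch.float32
INT64_BYTES = 8    # targets: torch.int64


def memmap_bytes(num_items, height=28, width=28):
    """Return (image_bytes, target_bytes) for one FashionMNISTData instance."""
    images = num_items * height * width * FLOAT32_BYTES
    targets = num_items * INT64_BYTES
    return images, targets


train_images, train_targets = memmap_bytes(60_000)  # training split
test_images, test_targets = memmap_bytes(10_000)    # test split
print(train_images, train_targets)  # 188160000 480000
print(f"{(train_images + train_targets) / 2**20:.1f} MiB")  # 179.9 MiB
```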
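DataLoaders
We create DataLoaders from the torchvision-provided datasets, as well as from our memory-mapped tensorclasses.
Since TensorDict implements __len__ and __getitem__ (and also __getitems__), and our tensorclass exposes the same interface, we can use it like a map-style dataset and create a DataLoader directly from it. Because TensorDict can already handle batched indexing, no collation is needed, so we pass the identity function as collate_fn.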
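batch_size = 64
# One-time conversion of each split into a memory-mapped tensorclass
# (the two tensorclasses described above).
training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)
# Baseline: DataLoaders over the torchvision datasets (per-sample ToTensor + default collation).
train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401
# Tensorclass DataLoaders: batches arrive already stacked, so collate_fn is the identity.
train_dataloader_tc = DataLoader(  # noqa: TOR401
    training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader(  # noqa: TOR401
    test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)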
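Why does the identity collate_fn work? For map-style datasets, recent PyTorch versions call `__getitems__` with the whole list of batch indices when the dataset defines it. Only if it does not do they fall back to one `__getitem__` call per index. Either way, the result is then passed to collate_fn. The sketch below mimics that logic in plain Python. `ColumnStore` and `fetch_batch` are hypothetical names, and this is a simplified model, not PyTorch's actual fetcher code.

```python
class ColumnStore:
    """Hypothetical map-style dataset that stores whole columns, like a tensorclass."""

    def __init__(self, images, targets):
        self.images, self.targets = images, targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return {"image": self.images[idx], "target": self.targets[idx]}

    def __getitems__(self, indices):
        # Batched indexing: gather each column once for the whole index list.
        return ColumnStore([self.images[i] for i in indices],
                           [self.targets[i] for i in indices])


def fetch_batch(dataset, indices, collate_fn):
    """Simplified model of how a DataLoader fetches one batch from a map-style dataset."""
    if getattr(dataset, "__getitems__", None):
        data = dataset.__getitems__(indices)  # one call returns an already-batched object
    else:
        data = [dataset[i] for i in indices]  # per-item calls; collate_fn must combine them
    return collate_fn(data)


store = ColumnStore(images=["a", "b", "c", "d"], targets=[0, 1, 2, 3])
batch = fetch_batch(store, [0, 2], collate_fn=lambda x: x)  # identity collate
print(type(batch).__name__, batch.targets)  # ColumnStore [0, 2]
```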
Model
We use the same model as in the Quickstart Tutorial.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))
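As a quick sanity check on the printed architecture, each Linear layer has in_features × out_features weights plus out_features biases. Flatten and ReLU have no parameters. The sketch below does that arithmetic. `linear_params` is a hypothetical helper; in a live session, `sum(p.numel() for p in model.parameters())` should give the same total.

```python
def linear_params(in_features, out_features, bias=True):
    """Weights (in * out) plus one bias per output unit."""
    return in_features * out_features + (out_features if bias else 0)


# Layer sizes of Net.linear_relu_stack; Flatten and ReLU add no parameters.
layers = [(28 * 28, 512), (512, 512), (512, 10)]
total = sum(linear_params(i, o) for i, o in layers)
print(total)  # 669706
```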
Optimizing the parameters
We use stochastic gradient descent and cross-entropy loss to optimize the model's parameters.
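For intuition about the objective, here is a plain-Python sketch of mean cross-entropy computed from raw logits, plus a bare SGD update. It is a simplified illustration, not PyTorch's implementation. It mirrors nn.CrossEntropyLoss's default of averaging over the batch, and `cross_entropy` and `sgd_step` are hypothetical helpers. It also helps explain why the first logged losses sit near 2.30: an untrained network with nearly uniform outputs over 10 classes has a loss close to ln(10) ≈ 2.3026.

```python
import math


def cross_entropy(batch_logits, targets):
    """Mean cross-entropy over a batch, computed from raw logits."""
    total = 0.0
    for logits, target in zip(batch_logits, targets):
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[target]  # -log softmax(logits)[target]
    return total / len(targets)


def sgd_step(params, grads, lr=1e-3):
    """Plain SGD update: p <- p - lr * grad (no momentum or weight decay)."""
    return [p - lr * g for p, g in zip(params, grads)]


# With uniform logits over 10 classes the loss is ln(10).
print(round(cross_entropy([[0.0] * 10], [3]), 4))  # 2.3026
```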
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
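The training loop for our tensorclass-based DataLoader is very similar. We only change how the data is unpacked, switching to the more explicit attribute-based access that the tensorclass provides. The .contiguous() method loads the data stored in the memmap tensors. No .to(device) call is needed here, because from_dataset already placed the tensorclass on device.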
def train_tc(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, data in enumerate(dataloader):
        # Attribute access replaces tuple unpacking; .contiguous() loads the memmap data.
        X, y = data.images.contiguous(), data.targets.contiguous()
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
def test_tc(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch.images.contiguous(), batch.targets.contiguous()
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
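The evaluation arithmetic in test and test_tc can be reproduced in a few lines. The reported average loss is the mean of per-batch mean losses. The last, smaller batch therefore carries the same weight as a full one: 10,000 test images at batch_size=64 give 156 full batches plus one of 16. The sketch below is a hypothetical plain-Python mirror of that logic, using nested lists in place of tensors.

```python
def argmax(values):
    """Index of the largest value (first one on ties), like pred.argmax(1) for one row."""
    return max(range(len(values)), key=values.__getitem__)


def evaluate(batches):
    """batches: list of (batch_logits, targets, mean_batch_loss) tuples."""
    size = sum(len(targets) for _, targets, _ in batches)
    correct = sum(
        argmax(logits) == target
        for batch_logits, targets, _ in batches
        for logits, target in zip(batch_logits, targets)
    )
    # Same as test(): average the per-batch mean losses, one equal weight per batch.
    avg_loss = sum(loss for _, _, loss in batches) / len(batches)
    return correct / size, avg_loss


batches = [
    ([[0.1, 0.9], [0.8, 0.2]], [1, 1], 0.5),  # one correct, one wrong
    ([[0.3, 0.7]], [1], 0.25),                # smaller final batch, still weighted 1/2
]
accuracy, avg_loss = evaluate(batches)
print(f"{accuracy:.3f} {avg_loss:.3f}")  # 0.667 0.375
```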
for d in train_dataloader_tc:
    print(d)
    break
import time
t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
    test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")
t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
    images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.306911 [ 0/60000]
loss: 2.290546 [ 6400/60000]
loss: 2.267823 [12800/60000]
loss: 2.268296 [19200/60000]
loss: 2.239312 [25600/60000]
loss: 2.222285 [32000/60000]
loss: 2.223311 [38400/60000]
loss: 2.189771 [44800/60000]
loss: 2.196405 [51200/60000]
loss: 2.161141 [57600/60000]
Test Error:
Accuracy: 48.7%, Avg loss: 2.154104
Epoch 2
-------------------------
loss: 2.168520 [ 0/60000]
loss: 2.153506 [ 6400/60000]
loss: 2.092801 [12800/60000]
loss: 2.109224 [19200/60000]
loss: 2.055430 [25600/60000]
loss: 2.006157 [32000/60000]
loss: 2.025786 [38400/60000]
loss: 1.945255 [44800/60000]
loss: 1.956291 [51200/60000]
loss: 1.879478 [57600/60000]
Test Error:
Accuracy: 59.4%, Avg loss: 1.879949
Epoch 3
-------------------------
loss: 1.919434 [ 0/60000]
loss: 1.884455 [ 6400/60000]
loss: 1.763246 [12800/60000]
loss: 1.796725 [19200/60000]
loss: 1.699208 [25600/60000]
loss: 1.656230 [32000/60000]
loss: 1.662529 [38400/60000]
loss: 1.567700 [44800/60000]
loss: 1.596351 [51200/60000]
loss: 1.477721 [57600/60000]
Test Error:
Accuracy: 61.8%, Avg loss: 1.509531
Epoch 4
-------------------------
loss: 1.582884 [ 0/60000]
loss: 1.547020 [ 6400/60000]
loss: 1.392604 [12800/60000]
loss: 1.457310 [19200/60000]
loss: 1.349238 [25600/60000]
loss: 1.348273 [32000/60000]
loss: 1.347054 [38400/60000]
loss: 1.278697 [44800/60000]
loss: 1.319937 [51200/60000]
loss: 1.205683 [57600/60000]
Test Error:
Accuracy: 63.1%, Avg loss: 1.246076
Epoch 5
-------------------------
loss: 1.329433 [ 0/60000]
loss: 1.309014 [ 6400/60000]
loss: 1.141340 [12800/60000]
loss: 1.241332 [19200/60000]
loss: 1.118123 [25600/60000]
loss: 1.152766 [32000/60000]
loss: 1.158768 [38400/60000]
loss: 1.103974 [44800/60000]
loss: 1.149251 [51200/60000]
loss: 1.051577 [57600/60000]
Test Error:
Accuracy: 64.4%, Avg loss: 1.084601
Tensorclass training done! time: 8.6816 s
Epoch 1
-------------------------
loss: 2.307089 [ 0/60000]
loss: 2.297157 [ 6400/60000]
loss: 2.284588 [12800/60000]
loss: 2.276901 [19200/60000]
loss: 2.242071 [25600/60000]
loss: 2.219066 [32000/60000]
loss: 2.220341 [38400/60000]
loss: 2.188907 [44800/60000]
loss: 2.183793 [51200/60000]
loss: 2.147927 [57600/60000]
Test Error:
Accuracy: 43.2%, Avg loss: 2.146838
Epoch 2
-------------------------
loss: 2.158795 [ 0/60000]
loss: 2.147090 [ 6400/60000]
loss: 2.094091 [12800/60000]
loss: 2.108519 [19200/60000]
loss: 2.044007 [25600/60000]
loss: 1.981364 [32000/60000]
loss: 2.003546 [38400/60000]
loss: 1.929537 [44800/60000]
loss: 1.926183 [51200/60000]
loss: 1.847433 [57600/60000]
Test Error:
Accuracy: 57.0%, Avg loss: 1.857306
Epoch 3
-------------------------
loss: 1.891647 [ 0/60000]
loss: 1.856303 [ 6400/60000]
loss: 1.746082 [12800/60000]
loss: 1.786176 [19200/60000]
loss: 1.664370 [25600/60000]
loss: 1.619249 [32000/60000]
loss: 1.633839 [38400/60000]
loss: 1.550666 [44800/60000]
loss: 1.565413 [51200/60000]
loss: 1.457971 [57600/60000]
Test Error:
Accuracy: 62.1%, Avg loss: 1.485636
Epoch 4
-------------------------
loss: 1.555910 [ 0/60000]
loss: 1.517564 [ 6400/60000]
loss: 1.375235 [12800/60000]
loss: 1.446436 [19200/60000]
loss: 1.323968 [25600/60000]
loss: 1.320687 [32000/60000]
loss: 1.333024 [38400/60000]
loss: 1.270114 [44800/60000]
loss: 1.298299 [51200/60000]
loss: 1.202938 [57600/60000]
Test Error:
Accuracy: 63.9%, Avg loss: 1.229178
Epoch 5
-------------------------
loss: 1.310207 [ 0/60000]
loss: 1.290053 [ 6400/60000]
loss: 1.126103 [12800/60000]
loss: 1.231878 [19200/60000]
loss: 1.110180 [25600/60000]
loss: 1.128504 [32000/60000]
loss: 1.154162 [38400/60000]
loss: 1.097552 [44800/60000]
loss: 1.130474 [51200/60000]
loss: 1.054337 [57600/60000]
Test Error:
Accuracy: 65.1%, Avg loss: 1.072044
Training done! time: 34.6215 s
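The logged timings can be summarized with a little arithmetic. This only restates the numbers from this particular CPU run; other hardware, or a GPU, may behave differently. Note also that the tensorclass timer starts after from_dataset has run, so the one-time conversion cost is not part of the 8.68 s figure. That cost falls within the remainder of the script's total running time of 1 min 1.114 s, presumably together with the dataset downloads.

```python
# Numbers copied from the log output of this run.
tensorclass_s = 8.6816  # "Tensorclass training done!" line
baseline_s = 34.6215    # "Training done!" line
total_s = 61.114        # total running time: 1 min 1.114 s
epochs = 5

speedup = baseline_s / tensorclass_s
untimed_s = total_s - tensorclass_s - baseline_s  # time outside both timed loops

print(f"speed-up: {speedup:.2f}x")  # speed-up: 3.99x
print(f"per epoch: {tensorclass_s / epochs:.2f} s vs {baseline_s / epochs:.2f} s")  # per epoch: 1.74 s vs 6.92 s
print(f"outside the timed loops: {untimed_s:.1f} s")  # outside the timed loops: 17.8 s
```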
Total running time of the script: (1 minute 1.114 seconds)