注意
點擊這裡下載完整的範例程式碼
編寫自訂資料集、DataLoader 和轉換¶
建立於:2017 年 6 月 10 日 | 上次更新:2024 年 1 月 19 日 | 上次驗證:2024 年 11 月 05 日
解決任何機器學習問題的大部分精力都用於準備資料。 PyTorch 提供了許多工具,使資料載入變得容易,並希望使您的程式碼更具可讀性。在本教學中,我們將了解如何從非平凡資料集載入和預處理/擴增資料。
要執行此教學,請確保已安裝以下套件
scikit-image
:用於影像 IO 和轉換pandas
:用於更輕鬆的 CSV 解析
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
<contextlib.ExitStack object at 0x7ff7f616b310>
我們將要處理的資料集是面部姿勢。這表示臉部像這樣註釋
總共為每張臉註釋了 68 個不同的地標點。
注意
從這裡下載資料集,以便影像位於名為 'data/faces/' 的目錄中。此資料集實際上是透過將出色的 dlib 的姿勢估計應用於一些從 imagenet 標記為 'face' 的影像而產生的。
資料集附帶一個 .csv
檔案,其中包含如下所示的註釋
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
讓我們從 CSV 中取得單個影像名稱及其註釋,在本例中,索引編號為 65 的 person-7.jpg 作為範例。讀取它,將影像名稱儲存在 img_name
中,並將其註釋儲存在 (L, 2) 陣列 landmarks
中,其中 L 是該列中地標的數量。
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)
print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
[33. 76.]
[34. 86.]
[34. 97.]]
讓我們編寫一個簡單的輔助函數來顯示影像及其地標,並使用它來顯示範例。
def show_landmarks(image, landmarks):
"""Show image with landmarks"""
plt.imshow(image)
plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
landmarks)
plt.show()
data:image/s3,"s3://crabby-images/86270/86270cc8847c834dd55e5d766035190feae1a379" alt="data loading tutorial"
資料集類別¶
torch.utils.data.Dataset
是一個代表資料集的抽象類別。您的自訂資料集應繼承 Dataset
並覆寫以下方法
__len__
使得len(dataset)
傳回資料集的大小。__getitem__
以支援索引,使得dataset[i]
可用於取得第 \(i\) 個樣本。
讓我們為我們的臉部地標資料集建立一個資料集類別。我們將在 __init__
中讀取 csv,但將影像讀取保留給 __getitem__
。這是記憶體有效率的,因為所有影像不會一次儲存在記憶體中,而是根據需要讀取。
我們資料集的樣本將是一個字典 {'image': image, 'landmarks': landmarks}
。我們的資料集將採用一個可選引數 transform
,以便可以在樣本上應用任何所需的處理。我們將在下一節中看到 transform
的用處。
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset."""
def __init__(self, csv_file, root_dir, transform=None):
"""
Arguments:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.root_dir,
self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:]
landmarks = np.array([landmarks], dtype=float).reshape(-1, 2)
sample = {'image': image, 'landmarks': landmarks}
if self.transform:
sample = self.transform(sample)
return sample
讓我們實例化此類別並迭代資料樣本。我們將列印前 4 個樣本的大小並顯示其地標。
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/')
fig = plt.figure()
for i, sample in enumerate(face_dataset):
print(i, sample['image'].shape, sample['landmarks'].shape)
ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)
if i == 3:
plt.show()
break
data:image/s3,"s3://crabby-images/6048e/6048e67b4df3abeeafadd9b34bdb4c207459cdad" alt="Sample #0, Sample #1, Sample #2, Sample #3"
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)
轉換¶
我們可以從上面看到的一個問題是樣本的大小不相同。大多數神經網路期望具有固定大小的影像。因此,我們需要編寫一些預處理程式碼。讓我們建立三個轉換
Rescale
:縮放影像RandomCrop
:從影像中隨機裁剪。這是資料擴增。ToTensor
:將 numpy 影像轉換為 torch 影像(我們需要交換軸)。
我們將它們編寫為可呼叫類別,而不是簡單的函數,以便不必每次呼叫時都傳遞轉換的參數。為此,我們只需要實作 __call__
方法,如果需要,實作 __init__
方法。然後,我們可以像這樣使用轉換
tsfm = Transform(params)
transformed_sample = tsfm(sample)
請注意以下這些轉換如何必須同時應用於影像和地標。
class Rescale(object):
"""Rescale the image in a sample to a given size.
Args:
output_size (tuple or int): Desired output size. If tuple, output is
matched to output_size. If int, smaller of image edges is matched
to output_size keeping aspect ratio the same.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
img = transform.resize(image, (new_h, new_w))
# h and w are swapped for landmarks because for images,
# x and y axes are axis 1 and 0 respectively
landmarks = landmarks * [new_w / w, new_h / h]
return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
"""Crop randomly the image in a sample.
Args:
output_size (tuple or int): Desired output size. If int, square crop
is made.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h + 1)
left = np.random.randint(0, w - new_w + 1)
image = image[top: top + new_h,
left: left + new_w]
landmarks = landmarks - [left, top]
return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
# swap color axis because
# numpy image: H x W x C
# torch image: C x H x W
image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image),
'landmarks': torch.from_numpy(landmarks)}
注意
在上面的範例中,RandomCrop 使用了外部函式庫的亂數產生器(在此例中為 Numpy 的 np.random.int)。這可能會導致 DataLoader 產生非預期的行為(請參閱這裡)。實務上,堅持使用 PyTorch 的亂數產生器會更安全,例如使用 torch.randint 來代替。
組合轉換 (Compose transforms)¶
現在,我們將轉換應用於一個樣本。
假設我們想要將影像的較短邊縮放到 256,然後從中隨機裁剪一個大小為 224 的正方形。也就是說,我們想要組合 Rescale
和 RandomCrop
轉換。torchvision.transforms.Compose
是一個簡單的可呼叫類別,允許我們執行此操作。
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
RandomCrop(224)])
# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample)
ax = plt.subplot(1, 3, i + 1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
show_landmarks(**transformed_sample)
plt.show()
data:image/s3,"s3://crabby-images/52988/529889e6578ac3f95c80b1e0ad1e70aad3b5f8a0" alt="Rescale, RandomCrop, Compose"
迭代資料集 (Iterating through the dataset)¶
讓我們將所有這些放在一起,以建立具有組合轉換的資料集。總結來說,每次對此資料集進行取樣時
會即時從檔案讀取影像
轉換會應用於讀取的影像
由於其中一個轉換是隨機的,因此在取樣時會進行資料擴增
我們可以像以前一樣,使用 for i in range
迴圈來迭代建立的資料集。
transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/',
transform=transforms.Compose([
Rescale(256),
RandomCrop(224),
ToTensor()
]))
for i, sample in enumerate(transformed_dataset):
print(i, sample['image'].size(), sample['landmarks'].size())
if i == 3:
break
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])
但是,透過使用簡單的 for
迴圈來迭代資料,我們會失去許多功能。特別是,我們會錯過
批次處理資料
隨機排序資料
使用
multiprocessing
workers 並行載入資料。
torch.utils.data.DataLoader
是一個迭代器,提供所有這些功能。下面使用的參數應該很清楚。一個值得注意的參數是 collate_fn
。您可以使用 collate_fn
指定需要如何對樣本進行批次處理。但是,預設的 collate 應該適用於大多數情況。
dataloader = DataLoader(transformed_dataset, batch_size=4,
shuffle=True, num_workers=0)
# Helper function to show a batch
def show_landmarks_batch(sample_batched):
"""Show image with landmarks for a batch of samples."""
images_batch, landmarks_batch = \
sample_batched['image'], sample_batched['landmarks']
batch_size = len(images_batch)
im_size = images_batch.size(2)
grid_border_size = 2
grid = utils.make_grid(images_batch)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
for i in range(batch_size):
plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
landmarks_batch[i, :, 1].numpy() + grid_border_size,
s=10, marker='.', c='r')
plt.title('Batch from dataloader')
# if you are using Windows, uncomment the next line and indent the for loop.
# you might need to go back and change ``num_workers`` to 0.
# if __name__ == '__main__':
for i_batch, sample_batched in enumerate(dataloader):
print(i_batch, sample_batched['image'].size(),
sample_batched['landmarks'].size())
# observe 4th batch and stop.
if i_batch == 3:
plt.figure()
show_landmarks_batch(sample_batched)
plt.axis('off')
plt.ioff()
plt.show()
break
data:image/s3,"s3://crabby-images/c8320/c8320658baa9c0cf923445e2ecd66ba4271b3946" alt="Batch from dataloader"
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
後記:torchvision (Afterword: torchvision)¶
在本教學中,我們已經了解如何編寫和使用資料集、轉換和資料載入器。torchvision
套件提供了一些常見的資料集和轉換。您甚至可能不需要編寫自定義類別。torchvision 中提供的更通用的資料集之一是 ImageFolder
。它假設影像以以下方式組織
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
其中 'ants'、'bees' 等是類別標籤。類似地,也可以使用對 PIL.Image
進行操作的通用轉換,例如 RandomHorizontalFlip
、Scale
。您可以使用這些來編寫資料載入器,如下所示
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)
有關訓練程式碼的範例,請參閱 電腦視覺的遷移學習教學。
腳本的總執行時間: ( 0 分鐘 2.385 秒)