Trainer Datasets Example

This is the dataset used for the training example. It uses the PyTorch Lightning library.
import os.path
import tarfile
from typing import Callable, Optional
import fsspec
import numpy
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets.folder import is_image_file
from tqdm import tqdm
This uses torchvision to define a dataset that we will later consume in the PyTorch Lightning data module.
class ImageFolderSamplesDataset(datasets.ImageFolder):
    """
    ImageFolderSamplesDataset is a wrapper around ImageFolder that allows you to
    limit the number of samples.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable[..., object]] = None,
        num_samples: Optional[int] = None,
        **kwargs: object,
    ) -> None:
        """
        Args:
            num_samples: optional. limits the size of the dataset
        """
        super().__init__(root, transform=transform)

        self.num_samples = num_samples

    def __len__(self) -> int:
        if self.num_samples is not None:
            return self.num_samples
        return super().__len__()
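As an example, the wrapper can be used anywhere a regular ImageFolder would be. A minimal sketch, assuming a hypothetical data/train directory laid out in the usual class-per-subfolder format:

# Minimal usage sketch (the data/train directory is a hypothetical example laid
# out in the usual ImageFolder format: data/train/<class>/<image>.jpeg).
img_transform = transforms.Compose([transforms.ToTensor()])
ds = ImageFolderSamplesDataset(
    root="data/train",
    transform=img_transform,
    num_samples=100,  # cap the reported dataset size for quick experiments
)
print(len(ds))  # 100, because num_samples overrides the real folder size
image, label = ds[0]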
For ease of use, we define a Lightning data module so it can be reused across the trainer and other components that need to load data.
# pyre-fixme[13]: Attribute `test_ds` is never initialized.
# pyre-fixme[13]: Attribute `train_ds` is never initialized.
# pyre-fixme[13]: Attribute `val_ds` is never initialized.
class TinyImageNetDataModule(pl.LightningDataModule):
    """
    TinyImageNetDataModule is a pytorch LightningDataModule for the tiny
    imagenet dataset.
    """

    train_ds: ImageFolderSamplesDataset
    val_ds: ImageFolderSamplesDataset
    test_ds: ImageFolderSamplesDataset

    def __init__(
        self, data_dir: str, batch_size: int = 16, num_samples: Optional[int] = None
    ) -> None:
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_samples = num_samples

    def setup(self, stage: Optional[str] = None) -> None:
        # Setup data loader and transforms
        img_transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        self.train_ds = ImageFolderSamplesDataset(
            root=os.path.join(self.data_dir, "train"),
            transform=img_transform,
            num_samples=self.num_samples,
        )
        self.val_ds = ImageFolderSamplesDataset(
            root=os.path.join(self.data_dir, "val"),
            transform=img_transform,
            num_samples=self.num_samples,
        )
        self.test_ds = ImageFolderSamplesDataset(
            root=os.path.join(self.data_dir, "test"),
            transform=img_transform,
            num_samples=self.num_samples,
        )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_ds, batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_ds, batch_size=self.batch_size)

    def teardown(self, stage: Optional[str] = None) -> None:
        pass
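The data module can then be handed straight to a Lightning Trainer. A minimal sketch, assuming a data directory with train/val/test subfolders and a hypothetical LightningModule named MyModel:

# Minimal usage sketch (MyModel is a hypothetical LightningModule; the data
# directory layout with train/, val/ and test/ subfolders is assumed).
datamodule = TinyImageNetDataModule(
    data_dir="data",
    batch_size=16,
    num_samples=100,  # keep runs short while smoke-testing
)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(MyModel(), datamodule=datamodule)
trainer.test(datamodule=datamodule)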
To pass data between the different components we use fsspec, which lets us read from and write to cloud or local file storage.
def download_data(remote_path: str, tmpdir: str) -> str:
    """
    download_data downloads the training data from the specified remote path via
    fsspec and places it in the tmpdir unextracted.
    """
    if os.path.isdir(remote_path):
        print("dataset path is a directory, using as is")
        return remote_path

    tar_path = os.path.join(tmpdir, "data.tar.gz")
    print(f"downloading dataset from {remote_path} to {tar_path}...")
    fs, _, rpaths = fsspec.get_fs_token_paths(remote_path)
    assert len(rpaths) == 1, "must have single path"
    fs.get(rpaths[0], tar_path)

    data_path = os.path.join(tmpdir, "data")
    print(f"extracting {tar_path} to {data_path}...")
    with tarfile.open(tar_path, mode="r") as f:
        f.extractall(data_path)

    return data_path
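Any fsspec-supported URL, or a plain local directory, can be passed in. A minimal sketch, assuming a hypothetical s3 bucket containing a data.tar.gz archive and an installed s3fs backend:

import tempfile

# Minimal usage sketch; the s3 URL is a hypothetical example and requires the
# corresponding fsspec backend (s3fs) to be installed.
with tempfile.TemporaryDirectory() as tmpdir:
    data_path = download_data("s3://my-bucket/tiny-imagenet/data.tar.gz", tmpdir)
    datamodule = TinyImageNetDataModule(data_dir=data_path, batch_size=16)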
def create_random_data(output_path: str, num_images: int = 250) -> None:
    """
    Fills the given path with randomly generated 64x64 images.
    This can be used for quick testing of the workflow of the model.
    Does NOT pack the files into a tar, but does preprocess them.
    """
    train_path = os.path.join(output_path, "train")
    class1_train_path = os.path.join(train_path, "class1")
    class2_train_path = os.path.join(train_path, "class2")

    val_path = os.path.join(output_path, "val")
    class1_val_path = os.path.join(val_path, "class1")
    class2_val_path = os.path.join(val_path, "class2")

    test_path = os.path.join(output_path, "test")
    class1_test_path = os.path.join(test_path, "class1")
    class2_test_path = os.path.join(test_path, "class2")

    paths = [
        class1_train_path,
        class1_val_path,
        class1_test_path,
        class2_train_path,
        class2_val_path,
        class2_test_path,
    ]

    for path in paths:
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

        for i in range(num_images):
            pixels = numpy.random.rand(64, 64, 3) * 255
            im = Image.fromarray(pixels.astype("uint8")).convert("RGB")
            im.save(os.path.join(path, f"rand_image_{i}.jpeg"))

    process_images(output_path)
def process_images(img_root: str) -> None:
    print("transforming images...")
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            transforms.ToPILImage(),
        ]
    )

    image_files = []
    for root, _, fnames in os.walk(img_root):
        for fname in fnames:
            path = os.path.join(root, fname)
            if not is_image_file(path):
                continue
            image_files.append(path)

    for path in tqdm(image_files, miniters=int(len(image_files) / 2000)):
        f = Image.open(path)
        f = transform(f)
        f.save(path)
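Together, create_random_data and the data module make it possible to smoke-test the whole loading pipeline without the real dataset. A minimal sketch that writes a throwaway dataset into a temporary directory:

import tempfile

# Minimal smoke-test sketch: generate a small random dataset, then load it
# through the data module defined above.
with tempfile.TemporaryDirectory() as tmpdir:
    create_random_data(tmpdir, num_images=8)
    datamodule = TinyImageNetDataModule(data_dir=tmpdir, batch_size=4)
    datamodule.setup()
    batch, labels = next(iter(datamodule.train_dataloader()))
    print(batch.shape)  # torch.Size([4, 3, 64, 64])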
# sphinx_gallery_thumbnail_path = '_static/img/gallery-lib.png'