捷徑

如何使用 CutMix 和 MixUp

注意

Colab 上嘗試,或前往結尾下載完整的範例程式碼。

CutMixMixUp 是常見的擴增策略,可以提高分類準確度。

這些轉換與 Torchvision 的其他轉換略有不同,因為它們預期輸入的是樣本的批次,而不是單個影像。在本範例中,我們將說明如何使用它們:在 DataLoader 之後,或作為整理函數的一部分。

import torch
from torchvision.datasets import FakeData
from torchvision.transforms import v2


NUM_CLASSES = 100

預處理流程

我們將使用一個簡單但典型的影像分類流程

preproc = v2.Compose([
    v2.PILToTensor(),
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # typically from ImageNet
])

dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)

img, label = dataset[0]
print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")
type(img) = <class 'torch.Tensor'>, img.dtype = torch.float32, img.shape = torch.Size([3, 224, 224]), label = 67

需要注意的重要一點是,CutMix 和 MixUp 都不是此預處理流程的一部分。我們稍後會在定義 DataLoader 後再添加它們。作為複習,如果我們不使用 CutMix 或 MixUp,DataLoader 和訓練迴圈會是這樣:

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in dataloader:
    print(f"{images.shape = }, {labels.shape = }")
    print(labels.dtype)
    # <rest of the training loop here>
    break
images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
torch.int64

在哪裡使用 MixUp 和 CutMix

在 DataLoader 之後

現在讓我們添加 CutMix 和 MixUp。最簡單的方法是在 DataLoader 之後立即執行此操作:DataLoader 已經為我們批次處理了影像和標籤,而這正是這些轉換所期望的輸入。

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

cutmix = v2.CutMix(num_classes=NUM_CLASSES)
mixup = v2.MixUp(num_classes=NUM_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

for images, labels in dataloader:
    print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
    images, labels = cutmix_or_mixup(images, labels)
    print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")

    # <rest of the training loop here>
    break
Before CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
After CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])

請注意標籤是如何轉換的:我們從形狀為 (batch_size,) 的批次標籤轉換為形狀為 (batch_size, num_classes) 的張量。轉換後的標籤仍然可以直接傳遞給損失函數,例如 torch.nn.functional.cross_entropy()

作為 collation function 的一部分

在 DataLoader 之後傳遞 transforms 是使用 CutMix 和 MixUp 最簡單的方法,但一個缺點是它沒有利用 DataLoader 的多進程處理。為此,我們可以將這些 transforms 作為 collation function 的一部分傳遞(請參閱 PyTorch 文件以了解更多關於 collation 的信息)。

from torch.utils.data import default_collate


def collate_fn(batch):
    return cutmix_or_mixup(*default_collate(batch))


dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)

for images, labels in dataloader:
    print(f"{images.shape = }, {labels.shape = }")
    # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
    # <rest of the training loop here>
    break
images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])

非標準輸入格式

到目前為止,我們使用了一個典型的樣本結構,其中我們傳遞 (images, labels) 作為輸入。 MixUp 和 CutMix 預設情況下可以神奇地與大多數常見的樣本結構一起使用:第二個參數是張量標籤的元組,或者帶有 "label[s]" 鍵的字典。 有關更多詳細信息,請查看 labels_getter 參數的文檔。

如果您的樣本具有不同的結構,您仍然可以通過將可呼叫物件傳遞給 labels_getter 參數來使用 CutMix 和 MixUp。 例如

batch = {
    "imgs": torch.rand(4, 3, 224, 224),
    "target": {
        "classes": torch.randint(0, NUM_CLASSES, size=(4,)),
        "some_other_key": "this is going to be passed-through"
    }
}


def labels_getter(batch):
    return batch["target"]["classes"]


out = v2.CutMix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")
out['imgs'].shape = torch.Size([4, 3, 224, 224]), out['target']['classes'].shape = torch.Size([4, 100])

腳本總運行時間: (0 分鐘 0.176 秒)

由 Sphinx-Gallery 生成的圖庫

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源