Transforms v2:端到端物件偵測/分割範例¶
原生支援物件偵測和分割任務:torchvision.transforms.v2
能夠聯合轉換影像、影片、邊界框和遮罩。
此範例展示了一個端到端的實例分割訓練案例,使用來自 torchvision.datasets
、torchvision.models
和 torchvision.transforms.v2
的 Torchvision 工具。此處涵蓋的所有內容都可以類似地應用於物件偵測或語意分割任務。
import pathlib
import torch
import torch.utils.data
from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2
torch.manual_seed(0)
# This loads fake data for illustration purposes of this example. In practice, you'll have
# to replace this with the proper data.
# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
ROOT = pathlib.Path("../assets") / "coco"
IMAGES_PATH = str(ROOT / "images")
ANNOTATIONS_PATH = str(ROOT / "instances.json")
from helpers import plot
資料集準備¶
首先,我們載入 CocoDetection
資料集,以查看它目前傳回的內容。
dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)
sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'list'>
type(target[0]) = <class 'dict'>
target[0].keys() = dict_keys(['segmentation', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id'])
Torchvision 資料集保留了資料結構和類型,如同資料集作者的意圖。因此,預設情況下,輸出結構可能並不總是與模型或轉換相容。
為了克服這個問題,我們可以使用 wrap_dataset_for_transforms_v2()
函數。對於 CocoDetection
,這會將目標結構更改為單一的列表字典
type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'dict'>
target.keys() = dict_keys(['boxes', 'masks', 'labels'])
type(target['boxes']) = <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>
type(target['labels']) = <class 'torch.Tensor'>
type(target['masks']) = <class 'torchvision.tv_tensors._mask.Mask'>
我們使用 target_keys
參數來指定我們感興趣的輸出種類。我們的資料集現在傳回一個目標,它是一個字典,其中值為 TVTensors(全部都是 torch.Tensor
的子類別)。我們已經從之前的輸出中刪除了所有不必要的鍵,但如果您需要任何原始鍵,例如「image_id」,您仍然可以要求它。
注意
如果您只想進行偵測,則不需要且不應該在 target_keys
中傳遞「masks」:如果樣本中存在遮罩,它們將被轉換,不必要地減慢您的轉換速度。
作為基準,讓我們看一下沒有轉換的樣本
plot([dataset[0], dataset[1]])

Transforms¶
現在讓我們定義我們的預處理轉換。所有轉換都知道如何在相關時處理影像、邊界框和遮罩。
轉換通常會作為資料集的 transforms
參數傳遞,以便它們可以利用來自 torch.utils.data.DataLoader
的多處理功能。
transforms = v2.Compose(
[
v2.ToImage(),
v2.RandomPhotometricDistort(p=1),
v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=1),
v2.SanitizeBoundingBoxes(),
v2.ToDtype(torch.float32, scale=True),
]
)
dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=transforms)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "masks"])
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
以下有幾點值得注意:
我們正在將 PIL 圖像轉換為
Image
物件。 這並非絕對必要,但依賴 Tensors(此處:Tensor 子類別)通常會更快。我們正在呼叫
SanitizeBoundingBoxes
以確保我們移除退化的邊界框,以及它們對應的標籤和遮罩。SanitizeBoundingBoxes
至少應放置在檢測管線的結尾一次;如果使用了RandomIoUCrop
,則尤其重要。
讓我們看看在我們的擴增管線中樣本看起來如何
plot([dataset[0], dataset[1]])

我們可以看到圖像的顏色被扭曲、放大或縮小,並且被翻轉。 邊界框和遮罩也相應地轉換。 而且,事不宜遲,我們可以開始訓練。
資料載入和訓練迴圈¶
下面我們使用的是 Mask-RCNN,這是一個實例分割模型,但本教學課程中涵蓋的所有內容也適用於物件檢測和語義分割任務。
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=2,
# We need a custom collation function here, since the object detection
# models expect a sequence of images and target dictionaries. The default
# collation function tries to torch.stack() the individual elements,
# which fails in general for object detection, because the number of bounding
# boxes varies between the images of the same batch.
collate_fn=lambda batch: tuple(zip(*batch)),
)
model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train()
for imgs, targets in data_loader:
loss_dict = model(imgs, targets)
# Put your training logic here
print(f"{[img.shape for img in imgs] = }")
print(f"{[type(target) for target in targets] = }")
for name, loss_val in loss_dict.items():
print(f"{name:<20}{loss_val:.3f}")
[img.shape for img in imgs] = [torch.Size([3, 512, 512]), torch.Size([3, 409, 493])]
[type(target) for target in targets] = [<class 'dict'>, <class 'dict'>]
loss_classifier 4.721
loss_box_reg 0.006
loss_mask 0.734
loss_objectness 0.691
loss_rpn_box_reg 0.036
訓練參考資料¶
從那裡,您可以查看 torchvision references,您可以在其中找到我們用來訓練模型的實際訓練腳本。
免責聲明 我們參考文檔中的程式碼比您自己的用例需要的更複雜:這是因為我們支援不同的後端(PIL、張量、TVTensors)和不同的轉換命名空間(v1 和 v2)。 因此,不要害怕簡化並且只保留您需要的內容。
腳本的總執行時間: (0 分鐘 4.834 秒)