如何編寫您自己的 v2 轉換¶
本指南說明如何編寫與 torchvision 轉換 V2 API 相容的轉換。
只需建立一個 nn.Module
並覆寫 forward
方法¶
在大多數情況下,這就是您所需要的一切,只要您已經知道您的轉換所期望的輸入結構。 例如,如果您只是進行圖像分類,您的轉換通常會接受單個圖像作為輸入,或一個 (img, label)
輸入。 因此,您可以將您的 forward
方法硬式編碼為僅接受該輸入,例如:
class MyCustomTransform(torch.nn.Module):
def forward(self, img, label):
# Do some transformations
return new_img, new_label
注意
這表示如果您有一個與 V1 轉換(torchvision.transforms
中的那些轉換)相容的自訂轉換,它仍然可以與 V2 轉換一起使用,而無需任何變更!
我們將在下面使用一個典型的偵測案例更完整地說明這一點,在該案例中,我們的樣本只是圖像、邊界框和標籤
class MyCustomTransform(torch.nn.Module):
def forward(self, img, bboxes, label): # we assume inputs are always structured like this
print(
f"I'm transforming an image of shape {img.shape} "
f"with bboxes = {bboxes}\n{label = }"
)
# Do some transformations. Here, we're just passing though the input
return img, bboxes, label
transforms = v2.Compose([
MyCustomTransform(),
v2.RandomResizedCrop((224, 224), antialias=True),
v2.RandomHorizontalFlip(p=1),
v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])
H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
)
label = 3
out_img, out_bboxes, out_label = transforms(img, bboxes, label)
I'm transforming an image of shape torch.Size([3, 256, 256]) with bboxes = BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
label = 3
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
Output image shape: torch.Size([3, 224, 224])
out_bboxes = BoundingBoxes([[224, 0, 224, 0],
[136, 0, 173, 0]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224))
out_label = 3
注意
在您的程式碼中使用 TVTensor 類別時,請務必熟悉此部分:我有一個 TVTensor,但現在我有一個 Tensor。 救命!
支援任意輸入結構¶
在上面的部分中,我們假設您已經知道您的輸入結構,並且您可以接受在您的程式碼中硬式編碼這個預期的結構。 如果您希望您的自訂轉換盡可能地靈活,這可能會有點限制。
內建的 Torchvision V2 轉換的一個主要特色是它們可以接受任意的輸入結構,並返回與輸出相同的結構(帶有轉換後的條目)。例如,轉換可以接受單個圖像,或 (img, label)
的元組,或任意嵌套的字典作為輸入。 以下是一個內建轉換 RandomHorizontalFlip
的範例
structured_input = {
"img": img,
"annotations": (bboxes, label),
"something that will be ignored": (1, "hello"),
"another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
The transformed bboxes are:
BoundingBoxes([[246, 10, 256, 20],
[186, 50, 206, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
基礎:覆寫 transform() 方法¶
為了在自定義轉換中支持任意輸入,您需要繼承 Transform
並覆寫 .transform() 方法(而不是 forward() 方法!)。以下是一個基本範例
class MyCustomTransform(v2.Transform):
def transform(self, inpt: Any, params: Dict[str, Any]):
if type(inpt) == torch.Tensor:
print(f"I'm transforming an image of shape {inpt.shape}")
return inpt + 1 # dummy transformation
elif isinstance(inpt, tv_tensors.BoundingBoxes):
print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
return tv_tensors.wrap(inpt + 100, like=inpt) # dummy transformation
my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
The transformed bboxes are:
BoundingBoxes([[100, 110, 110, 120],
[150, 150, 170, 170]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
需要注意的重要一點是,當我們在 structured_input
上呼叫 my_custom_transform
時,輸入會被扁平化,然後每個單獨的部分會被傳遞給 transform()
。也就是說,transform()`
會接收輸入圖像,然後是邊界框等等。在 transform()
中,您可以根據輸入的類型決定如何轉換每個輸入。
如果您好奇為什麼另一個張量 (torch.arange()
) 沒有被傳遞給 transform()
,請參閱 此註解 以了解更多詳細信息。
進階:make_params()
方法¶
make_params()
方法在每次輸入呼叫 transform()
之前在內部被呼叫。這通常對於生成隨機參數值非常有用。在下面的範例中,我們使用它以 0.5 的機率隨機應用轉換
class MyRandomTransform(MyCustomTransform):
def __init__(self, p=0.5):
self.p = p
super().__init__()
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
apply_transform = (torch.rand(size=(1,)) < self.p).item()
params = dict(apply_transform=apply_transform)
return params
def transform(self, inpt: Any, params: Dict[str, Any]):
if not params["apply_transform"]:
print("Not transforming anything!")
return inpt
else:
return super().transform(inpt, params)
my_random_transform = MyRandomTransform()
torch.manual_seed(0)
_ = my_random_transform(structured_input) # transforms
_ = my_random_transform(structured_input) # doesn't transform
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
Not transforming anything!
Not transforming anything!
注意
重要的是,這種隨機參數生成需要在 make_params()
中進行,而不是在 transform()
中進行,以便對於給定的轉換呼叫,相同的 RNG 以相同的方式應用於所有輸入。如果我們在 transform()
中執行 RNG,我們可能會冒著轉換圖像但 *不* 轉換邊界框的風險。
make_params()
方法將所有輸入的列表作為參數(此列表中的每個元素稍後將傳遞到 transform()
)。您可以使用 flat_inputs
來例如,使用 query_chw()
或 query_size()
找出輸入的尺寸。
make_params()
應該返回一個 dict(或者實際上,任何你想要的東西),然後它將被傳遞給 transform()
。
腳本總執行時間:(0 分鐘 0.009 秒)