作者:Vasilis Vryniotis

TorchVision 有一個新的向後相容的 API,用於建構具有多權重支援的模型。 新的 API 允許在相同的模型變體上載入不同的預訓練權重,追蹤重要的元數據(例如分類標籤),並包含使用模型所需的預處理轉換。 在這篇部落格文章中,我們計劃回顧原型 API、展示其功能,並強調與現有 API 的主要差異。

我們希望在最終確定 API 之前獲得您的想法。 為了收集您的回饋,我們創建了一個 Github issue,您可以在其中發布您的想法、問題和評論。

目前 API 的限制

TorchVision 目前提供預訓練模型,這些模型可以作為遷移學習的起點,或按原樣用於電腦視覺應用程式。 實例化預訓練模型並進行預測的典型方法是

import torch

from PIL import Image
from torchvision import models as M
from torchvision.transforms import transforms as T


img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model
model = M.resnet50(pretrained=True)
model.eval()

# Step 2: Define and initialize the inference transforms
preprocess = T.Compose([
    T.Resize([256, ]),
    T.CenterCrop(224),
    T.PILToTensor(),
    T.ConvertImageDtype(torch.float),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)

# Step 4: Use the model and print the predicted category
class_id = prediction.argmax().item()
score = prediction[class_id].item()
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
    category_name = categories[class_id]
print(f"{category_name}: {100 * score}%")

上述方法存在一些限制

  1. 無法支援多個預訓練權重:由於 pretrained 變數是布林值,我們只能提供一組權重。 當我們顯著提高現有模型的準確性並且我們希望將這些改進提供給社群時,這是一個嚴重的限制。 它也阻止我們在不同的資料集上提供相同模型變體的預訓練權重。
  2. 缺少推論/預處理轉換:使用者在使用模型之前被迫定義必要的轉換。 推論轉換通常與訓練過程和用於估計權重的資料集相關聯。 這些轉換中的任何細微差異(例如插值、調整大小/裁剪大小等)都可能導致準確性大幅降低或模型無法使用。
  3. 缺少元資料:與權重相關的關鍵資訊對使用者不可用。 例如,需要查看外部來源和文件才能找到諸如 類別標籤、訓練配方、準確度指標等資訊。

新的 API 解決了上述限制,並減少了標準任務所需的樣板程式碼數量。

原型 API 概述

讓我們看看如何使用新的 API 實現與上述完全相同的結果

from PIL import Image
from torchvision.prototype import models as PM


img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model
weights = PM.ResNet50_Weights.IMAGENET1K_V1
model = PM.resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)

# Step 4: Use the model and print the predicted category
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}*%*")

正如我們所看到的,新的 API 消除了上述限制。 讓我們詳細探索新功能。

多權重支援

新 API 的核心是我們能夠為相同的模型變體定義多個不同的權重。 每種模型建構方法(例如 resnet50)都有一個相關的 Enum 類別(例如 ResNet50_Weights),其項目數量與可用的預訓練權重數量相同。 此外,每個 Enum 類別都有一個指向特定模型最佳可用權重的 DEFAULT 別名。 這允許想要始終使用最佳可用權重的使用者這樣做,而無需修改他們的程式碼。

這是一個使用不同權重初始化模型的範例

from torchvision.prototype.models import resnet50, ResNet50_Weights

# Legacy weights with accuracy 76.130%
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# New weights with accuracy 80.858%
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Best available weights (currently alias for IMAGENET1K_V2)
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# No weights - random initialization
model = resnet50(weights=None)

相關的元資料 & 預處理轉換

每個模型的權重都與元資料相關聯。 我們儲存的資訊類型取決於模型的任務(分類、偵測、分割等)。 典型的資訊包括訓練配方的連結、插值模式、諸如類別和驗證指標等資訊。 這些值可以透過 meta 屬性以程式設計方式存取

from torchvision.prototype.models import ResNet50_Weights

# Accessing a single record
size = ResNet50_Weights.IMAGENET1K_V2.meta["size"]

# Iterating the items of the meta-data dictionary
for k, v in ResNet50_Weights.IMAGENET1K_V2.meta.items():
    print(k, v)

此外,每個權重項目都與必要的預處理轉換相關聯。目前所有的預處理轉換都是可以進行 JIT 編譯的,並且可以通過 transforms 屬性來存取。在使用資料之前,需要初始化/構建這些轉換。這種延遲初始化方案是為了確保解決方案的記憶體效率。轉換的輸入可以是 PIL.Image 或是使用 torchvision.io 讀取的 Tensor

from torchvision.prototype.models import ResNet50_Weights

# Initializing preprocessing at standard 224x224 resolution
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()

# Initializing preprocessing at 400x400 resolution
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms(crop_size=400, resize_size=400)

# Once initialized the callable can accept the image data:
# img_preprocessed = preprocess(img)

將權重與其元資料和預處理相關聯將提高透明度、改進可重現性,並使記錄如何產生一組權重變得更容易。

通過名稱獲取權重

能夠直接將權重與其屬性(元資料、預處理 callable 等)鏈接起來是我們實作中使用 Enums 而不是 Strings 的原因。然而,對於僅能取得權重名稱的情況,我們提供了一種能夠將權重名稱鏈接到其 Enums 的方法。

from torchvision.prototype.models import get_weight

# Weights can be retrieved by name:
assert get_weight("ResNet50_Weights.IMAGENET1K_V1") == ResNet50_Weights.IMAGENET1K_V1
assert get_weight("ResNet50_Weights.IMAGENET1K_V2") == ResNet50_Weights.IMAGENET1K_V2

# Including using the DEFAULT alias:
assert get_weight("ResNet50_Weights.DEFAULT") == ResNet50_Weights.IMAGENET1K_V2

棄用

在新的 API 中,先前用於將權重加載到完整模型或其主幹的布林值 pretrainedpretrained_backbone 參數已被棄用。目前的實作完全向後兼容,因為它可以無縫地將舊參數映射到新參數。使用舊參數到新的建構器會發出以下棄用警告

>>> model = torchvision.prototype.models.resnet50(pretrained=True)
 UserWarning: The parameter 'pretrained' is deprecated, please use 'weights' instead.
UserWarning:
Arguments other than a weight enum or `None` for 'weights' are deprecated.
The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`.
You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.

此外,建構器方法需要使用關鍵字參數。使用位置參數已被棄用,並且使用它們會發出以下警告

>>> model = torchvision.prototype.models.resnet50(None)
UserWarning:
Using 'weights' as positional parameter(s) is deprecated.
Please use keyword parameter(s) instead.

測試新的 API

遷移到新的 API 非常簡單。以下兩個 API 之間的調用都是等效的

# Using pretrained weights:
torchvision.prototype.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
torchvision.models.resnet50(pretrained=True)
torchvision.models.resnet50(True)

# Using no weights:
torchvision.prototype.models.resnet50(weights=None)
torchvision.models.resnet50(pretrained=False)
torchvision.models.resnet50(False)

請注意,原型功能僅在 TorchVision 的 nightly 版本上可用,因此要使用它,您需要按如下方式安裝它

conda install torchvision -c pytorch-nightly

有關安裝 nightly 的其他方法,請查看 PyTorch 的下載頁面。您也可以從最新的 main 分支從原始碼安裝 TorchVision;有關更多資訊,請查看我們的repo

使用新的 API 存取最先進的模型權重

如果您仍然不相信要試用新的 API,這裡還有一個理由。我們最近更新了我們的訓練配方,並從我們的許多模型中獲得了 SOTA 的準確性。可以通過新的 API 輕鬆存取改進的權重。以下是模型改進的快速概述

模型 舊的 Acc@1 新的 Acc@1
EfficientNet B1 78.642 79.838
MobileNetV3 Large 74.042 75.274
Quantized ResNet50 75.92 80.282
Quantized ResNeXt101 32x8d 78.986 82.574
RegNet X 400mf 72.834 74.864
RegNet X 800mf 75.212 77.522
RegNet X 1 6gf 77.04 79.668
RegNet X 3 2gf 78.364 81.198
RegNet X 8gf 79.344 81.682
RegNet X 16gf 80.058 82.72
RegNet X 32gf 80.622 83.018
RegNet Y 400mf 74.046 75.806
RegNet Y 800mf 76.42 78.838
RegNet Y 1 6gf 77.95 80.882
RegNet Y 3 2gf 78.948 81.984
RegNet Y 8gf 80.032 82.828
RegNet Y 16gf 80.424 82.89
RegNet Y 32gf 80.878 83.366
ResNet50 76.13 80.858
ResNet101 77.374 81.886
ResNet152 78.312 82.284
ResNeXt50 32x4d 77.618 81.198
ResNeXt101 32x8d 79.312 82.834
Wide ResNet50 2 78.468 81.602
Wide ResNet101 2 78.848 82.51

請花幾分鐘時間提供您對新 API 的反饋,因為這對於將其從原型畢業並將其包含在下一個版本中至關重要。您可以在專用的 Github Issue 上執行此操作。我們期待閱讀您的評論!