Models and pre-trained weights
The torchvision.models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.
General information on pre-trained weights
TorchVision offers pre-trained weights for every provided architecture, using the PyTorch torch.hub. Instantiating a pre-trained model downloads its weights to a cache directory. This directory can be set with the TORCH_HOME environment variable. See torch.hub.load_state_dict_from_url() for details.
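For example, here is a minimal sketch of redirecting the cache to a custom location (the path is hypothetical; the variable must be set before the first download):

import os

# Hypothetical cache location; downloaded weights land under <TORCH_HOME>/hub/checkpoints
os.environ["TORCH_HOME"] = "/data/torch_cache"

from torchvision.models import resnet50, ResNet50_Weights

# This download is now stored in the directory configured above
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)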
Note
The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.
Note
Backward compatibility is guaranteed for loading a serialized state_dict into a model created with an older version of PyTorch. On the contrary, loading an entire saved model or a serialized ScriptModule (serialized with an older version of PyTorch) may not preserve the historic behaviour. Refer to the PyTorch documentation on serialization for details.
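As a minimal illustration of the compatibility-friendly pattern (the file name is hypothetical), save and restore a state_dict rather than the whole model:

import torch
from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.DEFAULT)
# Persist only the parameters; this is the form covered by the guarantee above
torch.save(model.state_dict(), "resnet50_state.pth")  # hypothetical path

# Rebuild the architecture in code, then load the parameters into it
model = resnet50(weights=None)
model.load_state_dict(torch.load("resnet50_state.pth"))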
Initializing pre-trained models
As of v0.13, TorchVision offers a new Multi-weight support API for loading different weights into the existing model builder methods:
from torchvision.models import resnet50, ResNet50_Weights
# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)
# Strings are also supported
resnet50(weights="IMAGENET1K_V2")
# No weights - random initialization
resnet50(weights=None)
Migrating to the new API is very straightforward. The following method calls between the two APIs are all equivalent:
from torchvision.models import resnet50, ResNet50_Weights
# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True) # deprecated
resnet50(True) # deprecated
# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False) # deprecated
resnet50(False) # deprecated
Note that the pretrained parameter is now deprecated; using it will emit a warning, and it will be removed in v0.15.
Using the pre-trained models
Before using the pre-trained models, the images must be preprocessed (resized with the right resolution/interpolation, inference transforms applied, values rescaled, etc.). There is no standard way to do this, as it depends on how a given model was trained; it can vary across model families, variants, or even weight versions. Using the correct preprocessing method is critical, and failing to do so may lead to decreased accuracy or incorrect outputs.
All the necessary information for the inference transforms of each pre-trained model is provided in its weights documentation. To simplify inference, TorchVision bundles the necessary preprocessing transforms into each model weight. These are accessible via the weight.transforms attribute:
from torchvision.models import ResNet50_Weights

# Initialize the weight transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# Apply it to the input image
img_transformed = preprocess(img)
Some models use modules which have different training and evaluation behavior, such as batch normalization. To switch between these modes, use model.train() or model.eval() as appropriate. See train() or eval() for details.
from torchvision.models import resnet50, ResNet50_Weights

# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
# Set model to eval mode
model.eval()
Listing and retrieving available models
As of v0.14, TorchVision offers a new mechanism which allows listing and retrieving models and weights by their names. Here are a few examples on how to use them:
import torchvision
from torchvision.models import get_model, get_model_weights, get_weight, list_models
from torchvision.models.quantization import MobileNet_V3_Large_QuantizedWeights

# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)
# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2
Here are the available public functions to retrieve models and their corresponding weights:

get_model: Gets the model name and configuration and returns an instantiated model.
get_model_weights: Returns the weights enum class associated with the given model.
get_weight: Gets the weights enum value by its full name.
list_models: Returns a list with the names of registered models.
Using models from Hub
Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:
import torch
# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
You can also retrieve all the available weights of a specific model via PyTorch Hub:
import torch
weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])
The only exception is the detection models included in torchvision.models.detection. These models require TorchVision to be installed because they depend on custom C++ operators.
Classification
The following classification models are available, with or without pre-trained weights:
Here is an example of how to use a pre-trained image classification model:
from torchvision.io import decode_image
from torchvision.models import resnet50, ResNet50_Weights
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
The classes of the pre-trained model outputs can be found at weights.meta["categories"].
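As a quick sanity check, here is a small sketch (assuming the ResNet50 weights from the example above) that inspects the bundled label set:

from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
categories = weights.meta["categories"]
# ImageNet-1K weights ship 1000 category names
print(len(categories))
print(categories[:5])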
Table of all available classification weights
Accuracies are reported on ImageNet-1K using single crops:
| Weight | Acc@1 | Acc@5 | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 56.522 | 79.066 | 61.1M | 0.71 | |
| | 84.062 | 96.87 | 88.6M | 15.36 | |
| | 84.414 | 96.976 | 197.8M | 34.36 | |
| | 83.616 | 96.65 | 50.2M | 8.68 | |
| | 82.52 | 96.146 | 28.6M | 4.46 | |
| | 74.434 | 91.972 | 8.0M | 2.83 | |
| | 77.138 | 93.56 | 28.7M | 7.73 | |
| | 75.6 | 92.806 | 14.1M | 3.36 | |
| | 76.896 | 93.37 | 20.0M | 4.29 | |
| | 77.692 | 93.532 | 5.3M | 0.39 | |
| | 78.642 | 94.186 | 7.8M | 0.69 | |
| | 79.838 | 94.934 | 7.8M | 0.69 | |
| | 80.608 | 95.31 | 9.1M | 1.09 | |
| | 82.008 | 96.054 | 12.2M | 1.83 | |
| | 83.384 | 96.594 | 19.3M | 4.39 | |
| | 83.444 | 96.628 | 30.4M | 10.27 | |
| | 84.008 | 96.916 | 43.0M | 19.07 | |
| | 84.122 | 96.908 | 66.3M | 37.75 | |
| | 85.808 | 97.788 | 118.5M | 56.08 | |
| | 85.112 | 97.156 | 54.1M | 24.58 | |
| | 84.228 | 96.878 | 21.5M | 8.37 | |
| | 69.778 | 89.53 | 6.6M | 1.5 | |
| | 77.294 | 93.45 | 27.2M | 5.71 | |
| | 67.734 | 87.49 | 2.2M | 0.1 | |
| | 71.18 | 90.496 | 3.2M | 0.21 | |
| | 73.456 | 91.51 | 4.4M | 0.31 | |
| | 76.506 | 93.522 | 6.3M | 0.53 | |
| | 83.7 | 96.722 | 30.9M | 5.56 | |
| | 71.878 | 90.286 | 3.5M | 0.3 | |
| | 72.154 | 90.822 | 3.5M | 0.3 | |
| | 74.042 | 91.34 | 5.5M | 0.22 | |
| | 75.274 | 92.566 | 5.5M | 0.22 | |
| | 67.668 | 87.402 | 2.5M | 0.06 | |
| | 80.058 | 94.944 | 54.3M | 15.94 | |
| | 82.716 | 96.196 | 54.3M | 15.94 | |
| | 77.04 | 93.44 | 9.2M | 1.6 | |
| | 79.668 | 94.922 | 9.2M | 1.6 | |
| | 80.622 | 95.248 | 107.8M | 31.74 | |
| | 83.014 | 96.288 | 107.8M | 31.74 | |
| | 78.364 | 93.992 | 15.3M | 3.18 | |
| | 81.196 | 95.43 | 15.3M | 3.18 | |
| | 72.834 | 90.95 | 5.5M | 0.41 | |
| | 74.864 | 92.322 | 5.5M | 0.41 | |
| | 75.212 | 92.348 | 7.3M | 0.8 | |
| | 77.522 | 93.826 | 7.3M | 0.8 | |
| | 79.344 | 94.686 | 39.6M | 8 | |
| | 81.682 | 95.678 | 39.6M | 8 | |
| | 88.228 | 98.682 | 644.8M | 374.57 | |
| | 86.068 | 97.844 | 644.8M | 127.52 | |
| | 80.424 | 95.24 | 83.6M | 15.91 | |
| | 82.886 | 96.328 | 83.6M | 15.91 | |
| | 86.012 | 98.054 | 83.6M | 46.73 | |
| | 83.976 | 97.244 | 83.6M | 15.91 | |
| | 77.95 | 93.966 | 11.2M | 1.61 | |
| | 80.876 | 95.444 | 11.2M | 1.61 | |
| | 80.878 | 95.34 | 145.0M | 32.28 | |
| | 83.368 | 96.498 | 145.0M | 32.28 | |
| | 86.838 | 98.362 | 145.0M | 94.83 | |
| | 84.622 | 97.48 | 145.0M | 32.28 | |
| | 78.948 | 94.576 | 19.4M | 3.18 | |
| | 81.982 | 95.972 | 19.4M | 3.18 | |
| | 74.046 | 91.716 | 4.3M | 0.4 | |
| | 75.804 | 92.742 | 4.3M | 0.4 | |
| | 76.42 | 93.136 | 6.4M | 0.83 | |
| | 78.828 | 94.502 | 6.4M | 0.83 | |
| | 80.032 | 95.048 | 39.4M | 8.47 | |
| | 82.828 | 96.33 | 39.4M | 8.47 | |
| | 79.312 | 94.526 | 88.8M | 16.41 | |
| | 82.834 | 96.228 | 88.8M | 16.41 | |
| | 83.246 | 96.454 | 83.5M | 15.46 | |
| | 77.618 | 93.698 | 25.0M | 4.23 | |
| | 81.198 | 95.34 | 25.0M | 4.23 | |
| | 77.374 | 93.546 | 44.5M | 7.8 | |
| | 81.886 | 95.78 | 44.5M | 7.8 | |
| | 78.312 | 94.046 | 60.2M | 11.51 | |
| | 82.284 | 96.002 | 60.2M | 11.51 | |
| | 69.758 | 89.078 | 11.7M | 1.81 | |
| | 73.314 | 91.42 | 21.8M | 3.66 | |
| | 76.13 | 92.862 | 25.6M | 4.09 | |
| | 80.858 | 95.434 | 25.6M | 4.09 | |
| | 60.552 | 81.746 | 1.4M | 0.04 | |
| | 69.362 | 88.316 | 2.3M | 0.14 | |
| | 72.996 | 91.086 | 3.5M | 0.3 | |
| | 76.23 | 93.006 | 7.4M | 0.58 | |
| | 58.092 | 80.42 | 1.2M | 0.82 | |
| | 58.178 | 80.624 | 1.2M | 0.35 | |
| | 83.582 | 96.64 | 87.8M | 15.43 | |
| | 83.196 | 96.36 | 49.6M | 8.74 | |
| | 81.474 | 95.776 | 28.3M | 4.49 | |
| | 84.112 | 96.864 | 87.9M | 20.32 | |
| | 83.712 | 96.816 | 49.7M | 11.55 | |
| | 82.072 | 96.132 | 28.4M | 5.94 | |
| | 70.37 | 89.81 | 132.9M | 7.61 | |
| | 69.02 | 88.628 | 132.9M | 7.61 | |
| | 71.586 | 90.374 | 133.1M | 11.31 | |
| | 69.928 | 89.246 | 133.0M | 11.31 | |
| | 73.36 | 91.516 | 138.4M | 15.47 | |
| | 71.592 | 90.382 | 138.4M | 15.47 | |
| | nan | nan | 138.4M | 15.47 | |
| | 74.218 | 91.842 | 143.7M | 19.63 | |
| | 72.376 | 90.876 | 143.7M | 19.63 | |
| | 81.072 | 95.318 | 86.6M | 17.56 | |
| | 85.304 | 97.65 | 86.9M | 55.48 | |
| | 81.886 | 96.18 | 86.6M | 17.56 | |
| | 75.912 | 92.466 | 88.2M | 4.41 | |
| | 88.552 | 98.694 | 633.5M | 1016.72 | |
| | 85.708 | 97.73 | 632.0M | 167.29 | |
| | 79.662 | 94.638 | 304.3M | 61.55 | |
| | 88.064 | 98.512 | 305.2M | 361.99 | |
| | 85.146 | 97.422 | 304.3M | 61.55 | |
| | 76.972 | 93.07 | 306.5M | 15.38 | |
| | 78.848 | 94.284 | 126.9M | 22.75 | |
| | 82.51 | 96.02 | 126.9M | 22.75 | |
| | 78.468 | 94.086 | 68.9M | 11.4 | |
| | 81.602 | 95.758 | 68.9M | 11.4 | |
Quantized models
The following architectures provide support for INT8 quantized models, with or without pre-trained weights:
Here is an example of how to use a pre-trained quantized image classification model:
from torchvision.io import decode_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")
The classes of the pre-trained model outputs can be found at weights.meta["categories"].
Table of all available quantized classification weights
Accuracies are reported on ImageNet-1K using single crops:
| Weight | Acc@1 | Acc@5 | Params | GIPS | Recipe |
|---|---|---|---|---|---|
| | 69.826 | 89.404 | 6.6M | 1.5 | |
| | 77.176 | 93.354 | 27.2M | 5.71 | |
| | 71.658 | 90.15 | 3.5M | 0.3 | |
| | 73.004 | 90.858 | 5.5M | 0.22 | |
| | 78.986 | 94.48 | 88.8M | 16.41 | |
| | 82.574 | 96.132 | 88.8M | 16.41 | |
| | 82.898 | 96.326 | 83.5M | 15.46 | |
| | 69.494 | 88.882 | 11.7M | 1.81 | |
| | 75.92 | 92.814 | 25.6M | 4.09 | |
| | 80.282 | 94.976 | 25.6M | 4.09 | |
| | 57.972 | 79.78 | 1.4M | 0.04 | |
| | 68.36 | 87.582 | 2.3M | 0.14 | |
| | 72.052 | 90.7 | 3.5M | 0.3 | |
| | 75.354 | 92.488 | 7.4M | 0.58 | |
Semantic segmentation
Warning
The segmentation module is in Beta stage, and backward compatibility is not guaranteed.
The following semantic segmentation models are available, with or without pre-trained weights:
Here is an example of how to use a pre-trained semantic segmentation model:
from torchvision.io.image import decode_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
img = decode_image("gallery/assets/dog1.jpg")
# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()
The classes of the pre-trained model outputs can be found at weights.meta["categories"]. The output format of the models is illustrated in Semantic segmentation models.
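Continuing the snippet above, a minimal sketch of that output structure (the 21-channel shape is an assumption matching the 20 VOC categories plus a background channel):

# The model returns a dict of tensors; "out" holds the per-class logits
output = model(batch)
logits = output["out"]   # shape [N, num_classes, H, W]
print(logits.shape)      # e.g. torch.Size([1, 21, H, W])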
Table of all available semantic segmentation weights
All models are evaluated on a subset of COCO val2017, for the 20 categories that are also present in the Pascal VOC dataset:
| Weight | Mean IoU | pixelwise Acc | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 60.3 | 91.2 | 11.0M | 10.45 | |
| | 67.4 | 92.4 | 61.0M | 258.74 | |
| | 66.4 | 92.4 | 42.0M | 178.72 | |
| | 63.7 | 91.9 | 54.3M | 232.74 | |
| | 60.5 | 91.4 | 35.3M | 152.72 | |
| | 57.9 | 91.2 | 3.2M | 2.09 | |
Object detection, instance segmentation and person keypoint detection
The pre-trained models for detection, instance segmentation and keypoint detection are initialized with the classification models in torchvision. The models expect a list of Tensor[C, H, W]. Check the constructor of the models for more information.
Warning
The detection module is in Beta stage, and backward compatibility is not guaranteed.
Object detection
The following object detection models are available, with or without pre-trained weights:
Here is an example of how to use a pre-trained object detection model:
from torchvision.io.image import decode_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]
# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
labels=labels,
colors="red",
width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()
The classes of the pre-trained model outputs can be found at weights.meta["categories"]. For details on how to plot the bounding boxes of the models, you may refer to Instance segmentation models.
Table of all available object detection weights
Box MAPs are reported on COCO val2017:
| Weight | Box MAP | Params | GFLOPS | Recipe |
|---|---|---|---|---|
| | 39.2 | 32.3M | 128.21 | |
| | 22.8 | 19.4M | 0.72 | |
| | 32.8 | 19.4M | 4.49 | |
| | 46.7 | 43.7M | 280.37 | |
| | 37 | 41.8M | 134.38 | |
| | 41.5 | 38.2M | 152.24 | |
| | 36.4 | 34.0M | 151.54 | |
| | 25.1 | 35.6M | 34.86 | |
| | 21.3 | 3.4M | 0.58 | |
Instance segmentation
The following instance segmentation models are available, with or without pre-trained weights; a minimal usage sketch follows below.
For details on how to plot the masks of the models, you may refer to Instance segmentation models.
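This subsection ships no worked example of its own, so here is a minimal sketch mirroring the object detection example above (the 0.9 score threshold and 0.5 mask threshold are illustrative assumptions):

from torchvision.io.image import decode_image
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = maskrcnn_resnet50_fpn_v2(weights=weights)
model.eval()

# Steps 2 and 3: Initialize and apply the inference transforms
preprocess = weights.transforms()
batch = [preprocess(img)]

# Step 4: The prediction holds soft masks of shape [N, 1, H, W]; threshold them to binary masks
prediction = model(batch)[0]
masks = prediction["masks"][prediction["scores"] > 0.9] > 0.5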
Table of all available instance segmentation weights
Box and Mask MAPs are reported on COCO val2017:
| Weight | Box MAP | Mask MAP | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 47.4 | 41.8 | 46.4M | 333.58 | |
| | 37.9 | 34.6 | 44.4M | 134.38 | |
Keypoint detection
The following person keypoint detection models are available, with or without pre-trained weights; a minimal usage sketch follows below.
The classes of the pre-trained model outputs can be found at weights.meta["keypoint_names"]. For details on how to plot the bounding boxes of the models, you may refer to Visualizing keypoints.
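A minimal sketch under the same pattern as the detection example (the 17-keypoint COCO layout noted in the comment is an assumption):

from torchvision.io.image import decode_image
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
model = keypointrcnn_resnet50_fpn(weights=weights)
model.eval()

# Steps 2 and 3: Initialize and apply the inference transforms
preprocess = weights.transforms()
batch = [preprocess(img)]

# Step 4: keypoints come back as [num_detections, num_keypoints, 3] triples of (x, y, visibility)
prediction = model(batch)[0]
keypoints = prediction["keypoints"]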
Table of all available keypoint detection weights
Box and Keypoint MAPs are reported on COCO val2017:
| Weight | Box MAP | Keypoint MAP | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 50.6 | 61.1 | 59.1M | 133.92 | |
| | 54.6 | 65 | 59.1M | 137.42 | |
Video classification
Warning
The video module is in Beta stage, and backward compatibility is not guaranteed.
The following video classification models are available, with or without pre-trained weights:
Here is an example of how to use a pre-trained video classification model:
from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights
vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32] # optionally shorten duration
# Step 1: Initialize model with the best available weights
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")
The classes of the pre-trained model outputs can be found at weights.meta["categories"].
Table of all available video classification weights
Accuracies are reported on Kinetics-400 using single clips with clip length 16:
| Weight | Acc@1 | Acc@5 | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 63.96 | 84.13 | 11.7M | 43.34 | |
| | 78.477 | 93.582 | 36.6M | 70.6 | |
| | 80.757 | 94.665 | 34.5M | 64.22 | |
| | 67.463 | 86.175 | 31.5M | 40.52 | |
| | 63.2 | 83.479 | 33.4M | 40.7 | |
| | 68.368 | 88.05 | 8.3M | 17.98 | |
| | 79.427 | 94.386 | 88.0M | 140.67 | |
| | 81.643 | 95.574 | 88.0M | 140.67 | |
| | 79.521 | 94.158 | 49.8M | 82.84 | |
| | 77.715 | 93.519 | 28.2M | 43.88 | |
Optical flow
The following optical flow models are available, with or without pre-trained weights; a minimal usage sketch follows below.
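The section gives no worked example, so here is a minimal sketch of running a pre-trained flow model (the builder and weight names follow the torchvision.models.optical_flow API; reusing the test video from the video classification example is an assumption):

from torchvision.io.video import read_video
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

frames, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
img1, img2 = frames[0:1], frames[1:2]  # two consecutive frames as [1, C, H, W] batches

# Step 1: Initialize model with the best available weights
weights = Raft_Large_Weights.DEFAULT
model = raft_large(weights=weights)
model.eval()

# Steps 2 and 3: The bundled transforms take and return the pair of frame batches
transforms = weights.transforms()
img1, img2 = transforms(img1, img2)

# Step 4: RAFT returns a list of iteratively refined flow estimates; the last is the most accurate
list_of_flows = model(img1, img2)
predicted_flow = list_of_flows[-1]  # shape [1, 2, H, W]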