• 文件 >
  • 模型與預訓練權重
捷徑

模型與預訓練權重

torchvision.models 子套件包含用於解決不同任務的模型定義,包括:圖像分類、像素級語意分割、物件偵測、實例分割、人物關鍵點偵測、影片分類和光流。

關於預訓練權重的一般資訊

TorchVision 為每個提供的架構提供預訓練權重,使用 PyTorch torch.hub。 實例化預訓練模型會將其權重下載到快取目錄。 可以使用 TORCH_HOME 環境變數設定此目錄。 有關詳細訊息,請參閱 torch.hub.load_state_dict_from_url()

注意

此函式庫中提供的預訓練模型可能具有其自身的授權條款或源自用於訓練的資料集的條款和條件。 您有責任確定您是否有權限將模型用於您的使用案例。

注意

保證向後相容性,可將序列化的 state_dict 載入到使用舊版 PyTorch 建立的模型。 相反地,載入整個已儲存的模型或序列化的 ScriptModules(使用舊版 PyTorch 序列化)可能無法保留歷史行為。 請參閱以下文件

初始化預訓練模型

從 v0.13 開始,TorchVision 提供了一個新的 Multi-weight support API,用於將不同的權重載入到現有的模型建構器方法中

from torchvision.models import resnet50, ResNet50_Weights

# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)

# Strings are also supported
resnet50(weights="IMAGENET1K_V2")

# No weights - random initialization
resnet50(weights=None)

遷移到新的 API 非常簡單。 2 個 API 之間的以下方法呼叫都是等效的

from torchvision.models import resnet50, ResNet50_Weights

# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True)  # deprecated
resnet50(True)  # deprecated

# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False)  # deprecated
resnet50(False)  # deprecated

請注意,pretrained 參數現在已棄用,使用它會發出警告,並且將在 v0.15 中移除。

使用預訓練模型

在使用預訓練模型之前,必須預先處理影像(調整大小,使用正確的解析度/插補,應用推論轉換,重新調整值等等)。 沒有標準方法可以做到這一點,因為它取決於給定模型的訓練方式。 它可能因模型系列、變體甚至權重版本而異。 使用正確的預處理方法至關重要,否則可能會導致準確性降低或輸出不正確。

每個預訓練模型的推論轉換的所有必要資訊都在其權重文件中提供。 為了簡化推論,TorchVision 將必要的預處理轉換綁定到每個模型權重中。 這些可以透過 weight.transforms 屬性存取

# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

# Apply it to the input image
img_transformed = preprocess(img)

有些模型使用具有不同訓練和評估行為的模組,例如批次正規化 (batch normalization)。若要切換這些模式,請適當地使用 model.train()model.eval()。詳情請參閱 train()eval()

# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

# Set model to eval mode
model.eval()

列出並檢索可用的模型

從 v0.14 開始,TorchVision 提供了一種新的機制,允許依據模型和權重的名稱來列出和檢索它們。以下是一些關於如何使用它們的範例:

# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)

# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")

# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT

weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights

weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2

以下是可用的公開函式,用於檢索模型及其對應的權重:

get_model(name, **config)

取得模型名稱和配置,並傳回一個實例化的模型。

get_model_weights(name)

傳回與給定模型相關聯的權重列舉類別 (enum class)。

get_weight(name)

透過完整名稱取得權重列舉值 (enum value)。

list_models([module, include, exclude])

傳回已註冊模型名稱的列表。

使用來自 Hub 的模型

大多數預訓練模型可以透過 PyTorch Hub 直接存取,而無需安裝 TorchVision。

import torch

# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

您也可以透過 PyTorch Hub 檢索特定模型的所有可用權重,方法如下:

import torch

weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])

唯一的例外是 torchvision.models.detection 中包含的檢測模型。這些模型需要安裝 TorchVision,因為它們依賴於自定義的 C++ 運算符。

分類

以下分類模型可用,無論是否具有預訓練權重:


以下是如何使用預訓練圖像分類模型的範例:

from torchvision.io import decode_image
from torchvision.models import resnet50, ResNet50_Weights

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

預訓練模型輸出的類別可以在 weights.meta["categories"] 中找到。

所有可用的分類權重表

準確度是在 ImageNet-1K 上使用單一裁剪 (single crops) 報告的。

權重 (Weight)

Acc@1

Acc@5

參數 (Params)

GFLOPS

配方 (Recipe)

AlexNet_Weights.IMAGENET1K_V1

56.522

79.066

61.1M

0.71

連結 (link)

ConvNeXt_Base_Weights.IMAGENET1K_V1

84.062

96.87

88.6M

15.36

連結 (link)

ConvNeXt_Large_Weights.IMAGENET1K_V1

84.414

96.976

197.8M

34.36

連結 (link)

ConvNeXt_Small_Weights.IMAGENET1K_V1

83.616

96.65

50.2M

8.68

連結 (link)

ConvNeXt_Tiny_Weights.IMAGENET1K_V1

82.52

96.146

28.6M

4.46

連結 (link)

DenseNet121_Weights.IMAGENET1K_V1

74.434

91.972

8.0M

2.83

連結 (link)

DenseNet161_Weights.IMAGENET1K_V1

77.138

93.56

28.7M

7.73

連結 (link)

DenseNet169_Weights.IMAGENET1K_V1

75.6

92.806

14.1M

3.36

連結 (link)

DenseNet201_Weights.IMAGENET1K_V1

76.896

93.37

20.0M

4.29

連結 (link)

EfficientNet_B0_Weights.IMAGENET1K_V1

77.692

93.532

5.3M

0.39

連結 (link)

EfficientNet_B1_Weights.IMAGENET1K_V1

78.642

94.186

7.8M

0.69

連結 (link)

EfficientNet_B1_Weights.IMAGENET1K_V2

79.838

94.934

7.8M

0.69

連結 (link)

EfficientNet_B2_Weights.IMAGENET1K_V1

80.608

95.31

9.1M

1.09

連結 (link)

EfficientNet_B3_Weights.IMAGENET1K_V1

82.008

96.054

12.2M

1.83

連結 (link)

EfficientNet_B4_Weights.IMAGENET1K_V1

83.384

96.594

19.3M

4.39

連結 (link)

EfficientNet_B5_Weights.IMAGENET1K_V1

83.444

96.628

30.4M

10.27

連結 (link)

EfficientNet_B6_Weights.IMAGENET1K_V1

84.008

96.916

43.0M

19.07

連結 (link)

EfficientNet_B7_Weights.IMAGENET1K_V1

84.122

96.908

66.3M

37.75

連結 (link)

EfficientNet_V2_L_Weights.IMAGENET1K_V1

85.808

97.788

118.5M

56.08

連結 (link)

EfficientNet_V2_M_Weights.IMAGENET1K_V1

85.112

97.156

54.1M

24.58

連結 (link)

EfficientNet_V2_S_Weights.IMAGENET1K_V1

84.228

96.878

21.5M

8.37

連結 (link)

GoogLeNet_Weights.IMAGENET1K_V1

69.778

89.53

6.6M

1.5

連結 (link)

Inception_V3_Weights.IMAGENET1K_V1

77.294

93.45

27.2M

5.71

連結 (link)

MNASNet0_5_Weights.IMAGENET1K_V1

67.734

87.49

2.2M

0.1

連結 (link)

MNASNet0_75_Weights.IMAGENET1K_V1

71.18

90.496

3.2M

0.21

連結 (link)

MNASNet1_0_Weights.IMAGENET1K_V1

73.456

91.51

4.4M

0.31

連結 (link)

MNASNet1_3_Weights.IMAGENET1K_V1

76.506

93.522

6.3M

0.53

連結 (link)

MaxVit_T_Weights.IMAGENET1K_V1

83.7

96.722

30.9M

5.56

連結 (link)

MobileNet_V2_Weights.IMAGENET1K_V1

71.878

90.286

3.5M

0.3

連結 (link)

MobileNet_V2_Weights.IMAGENET1K_V2

72.154

90.822

3.5M

0.3

連結 (link)

MobileNet_V3_Large_Weights.IMAGENET1K_V1

74.042

91.34

5.5M

0.22

連結 (link)

MobileNet_V3_Large_Weights.IMAGENET1K_V2

75.274

92.566

5.5M

0.22

連結 (link)

MobileNet_V3_Small_Weights.IMAGENET1K_V1

67.668

87.402

2.5M

0.06

連結 (link)

RegNet_X_16GF_Weights.IMAGENET1K_V1

80.058

94.944

54.3M

15.94

連結 (link)

RegNet_X_16GF_Weights.IMAGENET1K_V2

82.716

96.196

54.3M

15.94

連結 (link)

RegNet_X_1_6GF_Weights.IMAGENET1K_V1

77.04

93.44

9.2M

1.6

連結 (link)

RegNet_X_1_6GF_Weights.IMAGENET1K_V2

79.668

94.922

9.2M

1.6

連結 (link)

RegNet_X_32GF_Weights.IMAGENET1K_V1

80.622

95.248

107.8M

31.74

連結 (link)

RegNet_X_32GF_Weights.IMAGENET1K_V2

83.014

96.288

107.8M

31.74

連結 (link)

RegNet_X_3_2GF_Weights.IMAGENET1K_V1

78.364

93.992

15.3M

3.18

連結 (link)

RegNet_X_3_2GF_Weights.IMAGENET1K_V2

81.196

95.43

15.3M

3.18

連結 (link)

RegNet_X_400MF_Weights.IMAGENET1K_V1

72.834

90.95

5.5M

0.41

連結 (link)

RegNet_X_400MF_Weights.IMAGENET1K_V2

74.864

92.322

5.5M

0.41

連結 (link)

RegNet_X_800MF_Weights.IMAGENET1K_V1

75.212

92.348

7.3M

0.8

連結 (link)

RegNet_X_800MF_Weights.IMAGENET1K_V2

77.522

93.826

7.3M

0.8

連結 (link)

RegNet_X_8GF_Weights.IMAGENET1K_V1

79.344

94.686

39.6M

8

連結 (link)

RegNet_X_8GF_Weights.IMAGENET1K_V2

81.682

95.678

39.6M

8

連結 (link)

RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1

88.228

98.682

644.8M

374.57

連結 (link)

RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1

86.068

97.844

644.8M

127.52

連結 (link)

RegNet_Y_16GF_Weights.IMAGENET1K_V1

80.424

95.24

83.6M

15.91

連結 (link)

RegNet_Y_16GF_Weights.IMAGENET1K_V2

82.886

96.328

83.6M

15.91

連結 (link)

RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1

86.012

98.054

83.6M

46.73

連結 (link)

RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1

83.976

97.244

83.6M

15.91

連結 (link)

RegNet_Y_1_6GF_Weights.IMAGENET1K_V1

77.95

93.966

11.2M

1.61

連結 (link)

RegNet_Y_1_6GF_Weights.IMAGENET1K_V2

80.876

95.444

11.2M

1.61

連結 (link)

RegNet_Y_32GF_Weights.IMAGENET1K_V1

80.878

95.34

145.0M

32.28

連結 (link)

RegNet_Y_32GF_Weights.IMAGENET1K_V2

83.368

96.498

145.0M

32.28

連結 (link)

RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1

86.838

98.362

145.0M

94.83

連結 (link)

RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1

84.622

97.48

145.0M

32.28

連結 (link)

RegNet_Y_3_2GF_Weights.IMAGENET1K_V1

78.948

94.576

19.4M

3.18

連結 (link)

RegNet_Y_3_2GF_Weights.IMAGENET1K_V2

81.982

95.972

19.4M

3.18

連結 (link)

RegNet_Y_400MF_Weights.IMAGENET1K_V1

74.046

91.716

4.3M

0.4

連結 (link)

RegNet_Y_400MF_Weights.IMAGENET1K_V2

75.804

92.742

4.3M

0.4

連結 (link)

RegNet_Y_800MF_Weights.IMAGENET1K_V1

76.42

93.136

6.4M

0.83

連結 (link)

RegNet_Y_800MF_Weights.IMAGENET1K_V2

78.828

94.502

6.4M

0.83

連結 (link)

RegNet_Y_8GF_Weights.IMAGENET1K_V1

80.032

95.048

39.4M

8.47

連結 (link)

RegNet_Y_8GF_Weights.IMAGENET1K_V2

82.828

96.33

39.4M

8.47

連結 (link)

ResNeXt101_32X8D_Weights.IMAGENET1K_V1

79.312

94.526

88.8M

16.41

連結 (link)

ResNeXt101_32X8D_Weights.IMAGENET1K_V2

82.834

96.228

88.8M

16.41

連結 (link)

ResNeXt101_64X4D_Weights.IMAGENET1K_V1

83.246

96.454

83.5M

15.46

連結 (link)

ResNeXt50_32X4D_Weights.IMAGENET1K_V1

77.618

93.698

25.0M

4.23

連結 (link)

ResNeXt50_32X4D_Weights.IMAGENET1K_V2

81.198

95.34

25.0M

4.23

連結 (link)

ResNet101_Weights.IMAGENET1K_V1

77.374

93.546

44.5M

7.8

連結 (link)

ResNet101_Weights.IMAGENET1K_V2

81.886

95.78

44.5M

7.8

連結 (link)

ResNet152_Weights.IMAGENET1K_V1

78.312

94.046

60.2M

11.51

連結 (link)

ResNet152_Weights.IMAGENET1K_V2

82.284

96.002

60.2M

11.51

連結 (link)

ResNet18_Weights.IMAGENET1K_V1

69.758

89.078

11.7M

1.81

連結 (link)

ResNet34_Weights.IMAGENET1K_V1

73.314

91.42

21.8M

3.66

連結 (link)

ResNet50_Weights.IMAGENET1K_V1

76.13

92.862

25.6M

4.09

連結 (link)

ResNet50_Weights.IMAGENET1K_V2

80.858

95.434

25.6M

4.09

連結 (link)

ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1

60.552

81.746

1.4M

0.04

連結 (link)

ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1

69.362

88.316

2.3M

0.14

連結 (link)

ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1

72.996

91.086

3.5M

0.3

連結 (link)

ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1

76.23

93.006

7.4M

0.58

連結 (link)

SqueezeNet1_0_Weights.IMAGENET1K_V1

58.092

80.42

1.2M

0.82

連結 (link)

SqueezeNet1_1_Weights.IMAGENET1K_V1

58.178

80.624

1.2M

0.35

連結 (link)

Swin_B_Weights.IMAGENET1K_V1

83.582

96.64

87.8M

15.43

連結 (link)

Swin_S_Weights.IMAGENET1K_V1

83.196

96.36

49.6M

8.74

連結 (link)

Swin_T_Weights.IMAGENET1K_V1

81.474

95.776

28.3M

4.49

連結 (link)

Swin_V2_B_Weights.IMAGENET1K_V1

84.112

96.864

87.9M

20.32

連結 (link)

Swin_V2_S_Weights.IMAGENET1K_V1

83.712

96.816

49.7M

11.55

連結 (link)

Swin_V2_T_Weights.IMAGENET1K_V1

82.072

96.132

28.4M

5.94

連結 (link)

VGG11_BN_Weights.IMAGENET1K_V1

70.37

89.81

132.9M

7.61

連結 (link)

VGG11_Weights.IMAGENET1K_V1

69.02

88.628

132.9M

7.61

連結 (link)

VGG13_BN_Weights.IMAGENET1K_V1

71.586

90.374

133.1M

11.31

連結 (link)

VGG13_Weights.IMAGENET1K_V1

69.928

89.246

133.0M

11.31

連結 (link)

VGG16_BN_Weights.IMAGENET1K_V1

73.36

91.516

138.4M

15.47

連結 (link)

VGG16_Weights.IMAGENET1K_V1

71.592

90.382

138.4M

15.47

連結 (link)

VGG16_Weights.IMAGENET1K_FEATURES

nan

nan

138.4M

15.47

連結 (link)

VGG19_BN_Weights.IMAGENET1K_V1

74.218

91.842

143.7M

19.63

連結 (link)

VGG19_Weights.IMAGENET1K_V1

72.376

90.876

143.7M

19.63

連結 (link)

ViT_B_16_Weights.IMAGENET1K_V1

81.072

95.318

86.6M

17.56

連結 (link)

ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1

85.304

97.65

86.9M

55.48

連結 (link)

ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1

81.886

96.18

86.6M

17.56

連結 (link)

ViT_B_32_Weights.IMAGENET1K_V1

75.912

92.466

88.2M

4.41

連結 (link)

ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1

88.552

98.694

633.5M

1016.72

連結 (link)

ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1

85.708

97.73

632.0M

167.29

連結 (link)

ViT_L_16_Weights.IMAGENET1K_V1

79.662

94.638

304.3M

61.55

連結 (link)

ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1

88.064

98.512

305.2M

361.99

連結 (link)

ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1

85.146

97.422

304.3M

61.55

連結 (link)

ViT_L_32_Weights.IMAGENET1K_V1

76.972

93.07

306.5M

15.38

連結 (link)

Wide_ResNet101_2_Weights.IMAGENET1K_V1

78.848

94.284

126.9M

22.75

連結 (link)

Wide_ResNet101_2_Weights.IMAGENET1K_V2

82.51

96.02

126.9M

22.75

連結 (link)

Wide_ResNet50_2_Weights.IMAGENET1K_V1

78.468

94.086

68.9M

11.4

連結 (link)

Wide_ResNet50_2_Weights.IMAGENET1K_V2

81.602

95.758

68.9M

11.4

連結 (link)

量化模型

以下架構提供對 INT8 量化模型的支援,無論是否具有預訓練權重


以下是如何使用預訓練量化影像分類模型的範例

from torchvision.io import decode_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")

預訓練模型輸出的類別可以在 weights.meta["categories"] 中找到。

所有可用量化分類權重表

準確度是在 ImageNet-1K 上使用單一裁剪 (single crops) 報告的。

權重 (Weight)

Acc@1

Acc@5

參數 (Params)

GIPS

配方 (Recipe)

GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1

69.826

89.404

6.6M

1.5

連結 (link)

Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1

77.176

93.354

27.2M

5.71

連結 (link)

MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1

71.658

90.15

3.5M

0.3

連結 (link)

MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1

73.004

90.858

5.5M

0.22

連結 (link)

ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1

78.986

94.48

88.8M

16.41

連結 (link)

ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V2

82.574

96.132

88.8M

16.41

連結 (link)

ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1

82.898

96.326

83.5M

15.46

連結 (link)

ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1

69.494

88.882

11.7M

1.81

連結 (link)

ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1

75.92

92.814

25.6M

4.09

連結 (link)

ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2

80.282

94.976

25.6M

4.09

連結 (link)

ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1

57.972

79.78

1.4M

0.04

連結 (link)

ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1

68.36

87.582

2.3M

0.14

連結 (link)

ShuffleNet_V2_X1_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1

72.052

90.7

3.5M

0.3

連結 (link)

ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1

75.354

92.488

7.4M

0.58

連結 (link)

語義分割

警告

分割模組處於 Beta 階段,不保證向後相容性。

以下語義分割模型可用,無論是否具有預訓練權重


以下是如何使用預訓練語義分割模型的範例

from torchvision.io.image import decode_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image

img = decode_image("gallery/assets/dog1.jpg")

# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()

預訓練模型輸出的類別可在 weights.meta["categories"] 中找到。 模型的輸出格式如 語義分割模型 中所示。

所有可用的語義分割權重表

所有模型都在 COCO val2017 的子集上進行評估,針對 Pascal VOC 資料集中存在的 20 個類別

權重 (Weight)

平均 IoU

像素準確率

參數 (Params)

GFLOPS

配方 (Recipe)

DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1

60.3

91.2

11.0M

10.45

連結 (link)

DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1

67.4

92.4

61.0M

258.74

連結 (link)

DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1

66.4

92.4

42.0M

178.72

連結 (link)

FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1

63.7

91.9

54.3M

232.74

連結 (link)

FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1

60.5

91.4

35.3M

152.72

連結 (link)

LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1

57.9

91.2

3.2M

2.09

連結 (link)

物件偵測、實例分割和人物關鍵點偵測

用於偵測、實例分割和關鍵點偵測的預訓練模型使用 torchvision 中的分類模型進行初始化。 這些模型預期為 Tensor[C, H, W] 的列表。 有關更多資訊,請查看模型的建構函式。

警告

偵測模組處於 Beta 階段,不保證向後相容性。

物件偵測

以下物件偵測模型可用,無論是否具有預訓練權重


以下是如何使用預訓練物件偵測模型的範例

from torchvision.io.image import decode_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]

# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()

預訓練模型輸出的類別可在 weights.meta["categories"] 中找到。 有關如何繪製模型邊界框的詳細資訊,您可以參考 實例分割模型

所有可用的物件偵測權重表

Box MAP 在 COCO val2017 上報告

權重 (Weight)

Box MAP

參數 (Params)

GFLOPS

配方 (Recipe)

FCOS_ResNet50_FPN_Weights.COCO_V1

39.2

32.3M

128.21

連結 (link)

FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1

22.8

19.4M

0.72

連結 (link)

FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1

32.8

19.4M

4.49

連結 (link)

FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1

46.7

43.7M

280.37

連結 (link)

FasterRCNN_ResNet50_FPN_Weights.COCO_V1

37

41.8M

134.38

連結 (link)

RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1

41.5

38.2M

152.24

連結 (link)

RetinaNet_ResNet50_FPN_Weights.COCO_V1

36.4

34.0M

151.54

連結 (link)

SSD300_VGG16_Weights.COCO_V1

25.1

35.6M

34.86

連結 (link)

SSDLite320_MobileNet_V3_Large_Weights.COCO_V1

21.3

3.4M

0.58

連結 (link)

實例分割

以下實例分割模型可用,無論是否具有預訓練權重


有關如何繪製模型遮罩的詳細資訊,您可以參考 實例分割模型

所有可用的實例分割權重表

Box 和 Mask MAP 在 COCO val2017 上報告

權重 (Weight)

Box MAP

Mask MAP

參數 (Params)

GFLOPS

配方 (Recipe)

MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1

47.4

41.8

46.4M

333.58

連結 (link)

MaskRCNN_ResNet50_FPN_Weights.COCO_V1

37.9

34.6

44.4M

134.38

連結 (link)

關鍵點偵測

以下人物關鍵點偵測模型可用,無論是否具有預訓練權重


預訓練模型輸出的類別可在 weights.meta["keypoint_names"] 中找到。 有關如何繪製模型邊界框的詳細資訊,您可以參考 可視化關鍵點

所有可用的關鍵點偵測權重表

Box 和 Keypoint MAP 在 COCO val2017 上報告

權重 (Weight)

Box MAP

Keypoint MAP

參數 (Params)

GFLOPS

配方 (Recipe)

KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY

50.6

61.1

59.1M

133.92

連結 (link)

KeypointRCNN_ResNet50_FPN_Weights.COCO_V1

54.6

65

59.1M

137.42

連結 (link)

影片分類

警告

影片模組處於 Beta 階段,不保證向後相容性。

以下影片分類模型可用,無論是否具有預訓練權重


這裡是如何使用預訓練視訊分類模型的一個範例

from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights

vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32]  # optionally shorten duration

# Step 1: Initialize model with the best available weights
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")

預訓練模型輸出的類別可以在 weights.meta["categories"] 中找到。

所有可用的視訊分類權重表

準確率是使用單一裁剪(clip length 16)在 Kinetics-400 上報告的

權重 (Weight)

Acc@1

Acc@5

參數 (Params)

GFLOPS

配方 (Recipe)

MC3_18_Weights.KINETICS400_V1

63.96

84.13

11.7M

43.34

連結 (link)

MViT_V1_B_Weights.KINETICS400_V1

78.477

93.582

36.6M

70.6

連結 (link)

MViT_V2_S_Weights.KINETICS400_V1

80.757

94.665

34.5M

64.22

連結 (link)

R2Plus1D_18_Weights.KINETICS400_V1

67.463

86.175

31.5M

40.52

連結 (link)

R3D_18_Weights.KINETICS400_V1

63.2

83.479

33.4M

40.7

連結 (link)

S3D_Weights.KINETICS400_V1

68.368

88.05

8.3M

17.98

連結 (link)

Swin3D_B_Weights.KINETICS400_V1

79.427

94.386

88.0M

140.67

連結 (link)

Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1

81.643

95.574

88.0M

140.67

連結 (link)

Swin3D_S_Weights.KINETICS400_V1

79.521

94.158

49.8M

82.84

連結 (link)

Swin3D_T_Weights.KINETICS400_V1

77.715

93.519

28.2M

43.88

連結 (link)

光流 (Optical Flow)

以下是可用的光流模型,無論是否有預訓練

文件

存取 PyTorch 的全面開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深度教學

檢視教學

資源

尋找開發資源並獲得您問題的解答

檢視資源