
Optimizing Vision Transformer Model for Deployment

Created On: Mar 15, 2021 | Last Updated: Jan 19, 2024 | Last Verified: Nov 05, 2024

Jeff Tang, Geeta Chauhan

Vision Transformer models apply the cutting-edge attention-based transformer models, introduced in natural language processing to achieve all kinds of state-of-the-art (SOTA) results, to computer vision tasks. Facebook Data-efficient Image Transformers (DeiT) is a Vision Transformer model trained on ImageNet for image classification.

In this tutorial, we will first cover what DeiT is and how to use it, then go through the complete steps of scripting, quantizing, optimizing, and using the model in iOS and Android apps. We will also compare the performance of quantized, optimized and non-quantized, non-optimized models, and show the benefits of applying quantization and optimization to the model along the way.

What is DeiT

Convolutional Neural Networks (CNNs) have been the main models for image classification since deep learning took off in 2012, but CNNs typically require hundreds of millions of images for training to achieve SOTA results. DeiT is a Vision Transformer model that requires much less data and computing resources for training to compete with the leading CNNs in performing image classification, which is made possible by two key components of DeiT:

  • Data augmentation that simulates training on a much larger dataset;

  • Native distillation that allows the transformer network to learn from a CNN's output.

DeiT shows that Transformers can be successfully applied to computer vision tasks, with limited access to data and resources. For more details on DeiT, please see the repo and paper.

Classifying Images with DeiT

Follow the README.md in the DeiT repository for detailed information on how to classify images using DeiT, or for a quick test, first install the required packages:

pip install torch torchvision timm pandas requests

To run in Google Colab, install the dependencies by running the following command:

!pip install timm pandas requests

then run the script below:

from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# should be 1.8.0 or above


model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),  # 3 corresponds to PIL bicubic interpolation
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.6.0+cu124
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/ci-user/.cache/torch/hub/main.zip
/usr/local/lib/python3.10/dist-packages/timm/models/registry.py:4: FutureWarning:

Importing from timm.models.registry is deprecated, please import via timm.models

/usr/local/lib/python3.10/dist-packages/timm/models/layers/__init__.py:48: FutureWarning:

Importing from timm.models.layers is deprecated, please import via timm.layers

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning:

Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning:

Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning:

Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning:

Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning:

Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning:

Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning:

Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning:

Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth

  0%|          | 0.00/330M [00:00<?, ?B/s]
  5%|4         | 15.6M/330M [00:00<00:02, 164MB/s]
 10%|#         | 34.4M/330M [00:00<00:01, 183MB/s]
 16%|#6        | 54.2M/330M [00:00<00:01, 194MB/s]
 22%|##2       | 74.2M/330M [00:00<00:01, 200MB/s]
 29%|##8       | 94.2M/330M [00:00<00:01, 203MB/s]
 35%|###4      | 114M/330M [00:00<00:01, 206MB/s]
 41%|####      | 134M/330M [00:00<00:00, 207MB/s]
 47%|####6     | 154M/330M [00:00<00:00, 190MB/s]
 52%|#####2    | 173M/330M [00:00<00:00, 190MB/s]
 58%|#####8    | 193M/330M [00:01<00:00, 196MB/s]
 64%|######4   | 213M/330M [00:01<00:00, 201MB/s]
 71%|#######   | 233M/330M [00:01<00:00, 203MB/s]
 76%|#######6  | 253M/330M [00:01<00:00, 193MB/s]
 83%|########2 | 273M/330M [00:01<00:00, 198MB/s]
 88%|########8 | 292M/330M [00:01<00:00, 195MB/s]
 94%|#########4| 312M/330M [00:01<00:00, 197MB/s]
100%|##########| 330M/330M [00:01<00:00, 197MB/s]
269

The output should be 269, which, according to the ImageNet list of class index to labels, maps to timber wolf, grey wolf, gray wolf, Canis lupus.
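As a quick way to confirm the mapping yourself, you can look up the predicted index in an ImageNet class-index file. The JSON URL below is one commonly mirrored copy and is an assumption, not part of the original tutorial; clsidx is the tensor computed in the script above.

import json
import requests

# Assumption: a mirrored copy of imagenet_class_index.json; any copy of the
# ImageNet class index file works. Keys are the string form of the class index.
LABELS_URL = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
class_idx = json.loads(requests.get(LABELS_URL).text)
print(class_idx[str(clsidx.item())])  # expected something like ['n02114367', 'timber_wolf']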

Now that we have verified that we can use the DeiT model to classify images, let's see how to modify the model so it can run on iOS and Android apps.

Scripting DeiT

To use the model on mobile, we first need to script the model. See the Script and Optimize recipe for a quick overview. Run the code below to convert the DeiT model used in the previous step into the TorchScript format that can run on mobile.

model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main

The scripted model file fbdeit_scripted.pt of size about 346MB is generated.

Quantizing DeiT

To reduce the trained model size significantly while keeping the inference accuracy about the same, quantization can be applied to the model. Thanks to the transformer model used in DeiT, we can easily apply dynamic quantization to the model, because dynamic quantization works best for LSTM and transformer models (see here for more details).

Now run the code below:

# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86" # replaced with ``qnnpack`` causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")

This generates the scripted and quantized version of the model, fbdeit_scripted_quantized.pt, with a size of about 89MB, a 74% reduction from the non-quantized model size of 346MB!
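To confirm the size reduction on your own run, you can compare the on-disk file sizes. This is a minimal sketch that assumes both files were saved to the current working directory by the steps above.

import os

# Compare the scripted model with its dynamically quantized version
# (assumes fbdeit_scripted.pt and fbdeit_scripted_quantized.pt are in the current directory).
for name in ["fbdeit_scripted.pt", "fbdeit_scripted_quantized.pt"]:
    print("{}: {:.1f}MB".format(name, os.path.getsize(name) / (1024 * 1024)))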

You can use scripted_quantized_model to generate the same inference result:

out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# The same output 269 should be printed
269

Optimizing DeiT

The final step before using the quantized and scripted model on mobile is to optimize it:

from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")

The generated fbdeit_optimized_scripted_quantized.pt file has about the same size as the quantized, scripted, but non-optimized model. The inference result remains the same.

out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# Again, the same output 269 should be printed
269

Using the Lite Interpreter

To see how much model size reduction and inference speedup the Lite Interpreter can result in, let's create the lite version of the model.

optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")

Although the lite model size is comparable to the non-lite version, inference speedup is expected when running the lite version on mobile.
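As a quick sanity check, not part of the original tutorial, the lite model loaded above should produce the same prediction as the other variants:

# The lite model loaded via torch.jit.load above should still predict class 269
# for the same input tensor used earlier.
out = ptl(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# The same output 269 should be printed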

Comparing Inference Speed

To see how the inference speed differs for the five models (the original model, the scripted model, the quantized and scripted model, the optimized quantized and scripted model, and the lite model), run the code below:

with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)

print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 96.01ms
scripted model: 105.84ms
scripted & quantized model: 124.12ms
scripted & quantized & optimized model: 135.20ms
lite model: 147.87ms

The results running on Google Colab are:

original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms

The following results summarize the inference time taken by each model and the percentage reduction of each model relative to the original model.

import pandas as pd
import numpy as np

df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
    ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
    ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
    columns=['Inference Time', 'Reduction'])], axis=1)

print(df)

"""
        Model                             Inference Time    Reduction
0   original model                             1236.69ms           0%
1   scripted model                             1226.72ms        0.81%
2   scripted & quantized model                  593.19ms       52.03%
3   scripted & quantized & optimized model      598.01ms       51.64%
4   lite model                                  600.72ms       51.43%
"""
                                    Model  ... Reduction
0                          original model  ...        0%
1                          scripted model  ...   -10.25%
2              scripted & quantized model  ...   -29.28%
3  scripted & quantized & optimized model  ...   -40.82%
4                              lite model  ...   -54.02%

[5 rows x 3 columns]

