Note
Click here to download the full example code
Optimizing Vision Transformer Model for Deployment
Created On: Mar 15, 2021 | Last Updated: Jan 19, 2024 | Last Verified: Nov 05, 2024
Vision Transformer models apply the cutting-edge attention-based transformer models, introduced in natural language processing, to computer vision tasks to achieve various state-of-the-art (SOTA) results. Facebook Data-efficient Image Transformers (DeiT) is a Vision Transformer model trained on ImageNet for image classification.
In this tutorial, we will first cover what DeiT is and how to use it, then go through the complete steps of scripting, quantizing, optimizing, and using the model in iOS and Android apps. We will also compare the performance of the quantized, optimized model with the non-quantized, non-optimized model, and show the benefits of applying quantization and optimization to the model along the way.
What is DeiT
Convolutional Neural Networks (CNNs) have been the main models for image classification since deep learning took off in 2012, but CNNs typically require hundreds of millions of images for training to achieve SOTA results. DeiT is a Vision Transformer model that requires much less data and computing resources for training to compete with the leading CNNs in image classification, which is made possible by two key components of DeiT:
Data augmentation that simulates training on a much larger dataset;
Native distillation that allows the transformer network to learn from a CNN's output (a minimal sketch of such a distillation loss follows this list).
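To make the distillation component concrete, below is a minimal, hypothetical sketch of the hard-distillation objective described in the DeiT paper: the student's class-token head learns from the ground-truth labels, while a separate distillation-token head learns from the CNN teacher's hard predictions. The function name, tensor names, and the equal 50/50 weighting are illustrative assumptions; this is not code from the tutorial.
import torch
import torch.nn.functional as F

def hard_distillation_loss(student_cls_logits, student_dist_logits, teacher_logits, labels):
    # Class-token head: standard cross-entropy against the ground-truth labels.
    ce_loss = F.cross_entropy(student_cls_logits, labels)
    # Distillation-token head: cross-entropy against the teacher's hard predictions.
    teacher_labels = teacher_logits.argmax(dim=1)
    dist_loss = F.cross_entropy(student_dist_logits, teacher_labels)
    # Equal weighting of the two terms, following DeiT's hard distillation.
    return 0.5 * ce_loss + 0.5 * dist_loss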
DeiT shows that Transformers can be successfully applied to computer vision tasks, even with limited access to data and resources. For more details on DeiT, see the repo and paper.
Classifying Images with DeiT
Follow the README.md in the DeiT repository for detailed information on how to classify images using DeiT, or for a quick test, first install the required packages:
pip install torch torchvision timm pandas requests
To run in Google Colab, install the dependencies by running the following command:
!pip install timm pandas requests
then run the script below:
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
print(torch.__version__)
# should be 1.8.0 or later
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),  # 3 corresponds to bicubic interpolation
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.6.0+cu124
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/ci-user/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
269
The output should be 269, which, according to the ImageNet list of class index to labels file, maps to timber wolf, grey wolf, gray wolf, Canis lupus.
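If you want to map the index to a human-readable label programmatically, one option is to fetch an ImageNet class list and index into it. The URL below is an assumption based on common PyTorch Hub examples and is not part of this tutorial:
# Assumption: this class list matches the label order the model was trained with.
labels = requests.get("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt").text.splitlines()
print(labels[clsidx.item()])  # expected to print the timber wolf class name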
Now that we have verified that the DeiT model can be used to classify images, let's see how to modify the model so it can run on iOS and Android apps.
Scripting DeiT
To use the model on mobile, we first need to script the model. See the Script and Optimize recipe for a quick overview. Run the code below to convert the DeiT model used in the previous step to the TorchScript format that can run on mobile.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main
The scripted model file fbdeit_scripted.pt of size about 346MB is generated.
Quantizing DeiT
To reduce the trained model size significantly while keeping the inference accuracy about the same, quantization can be applied to the model. Thanks to the transformer model used in DeiT, we can easily apply dynamic quantization to the model, because dynamic quantization works best for LSTM and transformer models (see here for more details).
Now run the code below:
# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86" # replacing this with ``qnnpack`` causes much worse inference speed for the quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
This generates the scripted and quantized version of the model, fbdeit_scripted_quantized.pt, with size about 89MB, a 74% reduction from the unquantized model size of 346MB!
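To verify the reduction on your own machine, a quick standard-library check of the files saved above works (a sketch, assuming both files are in the current working directory):
import os
# Compare the on-disk sizes of the scripted and the scripted-and-quantized models.
for name in ["fbdeit_scripted.pt", "fbdeit_scripted_quantized.pt"]:
    print(name, "{:.0f}MB".format(os.path.getsize(name) / 1e6))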
You can use scripted_quantized_model to generate the same inference result:
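A quick check, reusing the img tensor prepared earlier:
out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
# The same output 269 should be printed.
print(clsidx.item())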
269
Optimizing DeiT
The final step before using the quantized and scripted model on mobile is to optimize it:
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")
The generated fbdeit_optimized_scripted_quantized.pt file has about the same size as the quantized, scripted, but non-optimized model. The inference result remains the same:
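Again, a quick check with the same input:
out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
# Again, the same output 269 should be printed.
print(clsidx.item())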
269
Using Lite Interpreter
To see how much model size reduction and inference speed up the Lite Interpreter can result in, let's create the lite version of the model.
optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")
Although the lite model size is comparable to the non-lite version, when running the lite version on mobile, the inference speed up is expected.
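To confirm that the lite and non-lite model files are indeed comparable in size, the same kind of check as before works (again, assuming the files saved above are in the current working directory):
import os
for name in ["fbdeit_optimized_scripted_quantized.pt", "fbdeit_optimized_scripted_quantized_lite.ptl"]:
    print(name, "{:.0f}MB".format(os.path.getsize(name) / 1e6))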
Comparing Inference Speed
To see how the inference speed differs for the five models - the original model, the scripted model, the quantized-and-scripted model, the optimized-quantized-and-scripted model, and the lite model - run the code below:
with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)
print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 96.01ms
scripted model: 105.84ms
scripted & quantized model: 124.12ms
scripted & quantized & optimized model: 135.20ms
lite model: 147.87ms
The results running on Google Colab are:
original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms
The following results summarize the inference time taken by each model and the percentage reduction of each model relative to the original model.
import pandas as pd
import numpy as np
df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
    ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
    ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
    columns=['Inference Time', 'Reduction'])], axis=1)
print(df)
"""
Model Inference Time Reduction
0 original model 1236.69ms 0%
1 scripted model 1226.72ms 0.81%
2 scripted & quantized model 593.19ms 52.03%
3 scripted & quantized & optimized model 598.01ms 51.64%
4 lite model 600.72ms 51.43%
"""
Model ... Reduction
0 original model ... 0%
1 scripted model ... -10.25%
2 scripted & quantized model ... -29.28%
3 scripted & quantized & optimized model ... -40.82%
4 lite model ... -54.02%
[5 rows x 3 columns]