(prototype) FX 圖形模式後訓練靜態量化¶
建立於:2021 年 2 月 8 日 | 最後更新:2025 年 1 月 24 日 | 最後驗證:2024 年 11 月 5 日
作者:Jerry Zhang 編輯者:Charles Hernandez
本教學介紹基於 torch.fx 在圖形模式下執行後訓練靜態量化的步驟。 FX 圖形模式量化的優點是我們可以在模型上完全自動地執行量化。 雖然可能需要一些工作才能使模型與 FX 圖形模式量化相容(可使用 torch.fx
符號追蹤),但我們將提供一個單獨的教學來展示如何使我們要量化的模型部分與 FX 圖形模式量化相容。 我們還有一個關於 FX 圖形模式後訓練動態量化的教學。 總之; FX 圖形模式 API 如下所示
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import QConfigMapping
float_model.eval()
# The old 'fbgemm' is still available but 'x86' is the recommended default.
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
def calibrate(model, data_loader):
model.eval()
with torch.no_grad():
for image, target in data_loader:
model(image)
example_inputs = (next(iter(data_loader))[0]) # get an example input
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs) # fuse modules and insert observers
calibrate(prepared_model, data_loader_test) # run calibration on sample data
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
1. FX 圖形模式量化的動機¶
目前,PyTorch 只有 eager 模式量化作為替代方案:PyTorch 中使用 Eager 模式的靜態量化。
我們可以觀察到 eager 模式量化過程中涉及多個手動步驟,包括
顯式量化和反量化激活 - 當模型中混合使用浮點和量化運算時,這非常耗時。
顯式融合模組 - 這需要手動識別卷積、批次正規化和 relu 以及其他融合模式的序列。
需要對 pytorch 張量運算(例如 add、concat 等)進行特殊處理。
Functionals 沒有一流的支援(functional.conv2d 和 functional.linear 將不會被量化)
大多數這些修改來自 eager 模式量化的底層限制。 Eager 模式在模組級別工作,因為它無法檢查實際執行的程式碼(在 forward 函數中),量化是透過模組交換實現的,並且我們不知道模組在 eager 模式的 forward 函數中是如何使用的,因此它需要使用者手動插入 QuantStub 和 DeQuantStub 以標記他們想要量化或反量化的點。 在圖形模式下,我們可以檢查 forward 函數中已執行的實際程式碼(例如 aten 函數呼叫),並且量化是透過模組和圖形操作實現的。 由於圖形模式具有已執行程式碼的完全可見性,因此我們的工具能夠自動計算出諸如要融合哪些模組以及在何處插入觀察者呼叫、量化/反量化函數等,我們能夠自動化整個量化過程。
FX 圖形模式量化的優點是
簡單的量化流程,最少的手動步驟
解鎖了執行更高等級優化的可能性,例如自動精度選擇
2. 定義輔助函數並準備資料集¶
我們將從執行必要的導入、定義一些輔助函數並準備資料開始。 這些步驟與 PyTorch 中使用 Eager 模式的靜態量化相同。
若要使用完整的 ImageNet 資料集執行本教學中的程式碼,請先按照此處的說明下載 imagenet ImageNet Data。 將下載的檔案解壓縮到 ‘data_path’ 資料夾中。
下載 torchvision resnet18 模型並將其重新命名為 data/resnet18_pretrained_float.pth
。
import os
import sys
import time
import numpy as np
import torch
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
from torchvision.models.resnet import resnet18
import torchvision.transforms as transforms
# Set up warnings
import warnings
warnings.filterwarnings(
action='ignore',
category=DeprecationWarning,
module=r'.*'
)
warnings.filterwarnings(
action='default',
module=r'torch.ao.quantization'
)
# Specify random seed for repeatable results
_ = torch.manual_seed(191009)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def evaluate(model, criterion, data_loader):
model.eval()
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
cnt = 0
with torch.no_grad():
for image, target in data_loader:
output = model(image)
loss = criterion(output, target)
cnt += 1
acc1, acc5 = accuracy(output, target, topk=(1, 5))
top1.update(acc1[0], image.size(0))
top5.update(acc5[0], image.size(0))
print('')
return top1, top5
def load_model(model_file):
model = resnet18(pretrained=False)
state_dict = torch.load(model_file, weights_only=True)
model.load_state_dict(state_dict)
model.to("cpu")
return model
def print_size_of_model(model):
if isinstance(model, torch.jit.RecursiveScriptModule):
torch.jit.save(model, "temp.p")
else:
torch.jit.save(torch.jit.script(model), "temp.p")
print("Size (MB):", os.path.getsize("temp.p")/1e6)
os.remove("temp.p")
def prepare_data_loaders(data_path):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = torchvision.datasets.ImageNet(
data_path, split="train", transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
dataset_test = torchvision.datasets.ImageNet(
data_path, split="val", transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=train_batch_size,
sampler=train_sampler)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=eval_batch_size,
sampler=test_sampler)
return data_loader, data_loader_test
data_path = '~/.data/imagenet'
saved_model_dir = 'data/'
float_model_file = 'resnet18_pretrained_float.pth'
train_batch_size = 30
eval_batch_size = 50
data_loader, data_loader_test = prepare_data_loaders(data_path)
example_inputs = (next(iter(data_loader))[0])
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cpu")
float_model.eval()
# create another instance of the model since
# we need to keep the original model around
model_to_quantize = load_model(saved_model_dir + float_model_file).to("cpu")
4. 使用 QConfigMapping
指定如何量化模型¶
qconfig_mapping = QConfigMapping.set_global(default_qconfig)
我們使用與 eager 模式量化相同的 qconfig,qconfig
只是激活和權重的觀察器的命名元組。 QConfigMapping
包含從 ops 到 qconfigs 的對應資訊
qconfig_mapping = (QConfigMapping()
.set_global(qconfig_opt) # qconfig_opt is an optional qconfig, either a valid qconfig or None
.set_object_type(torch.nn.Conv2d, qconfig_opt) # can be a callable...
.set_object_type("reshape", qconfig_opt) # ...or a string of the method
.set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig_opt) # matched in order, first match takes precedence
.set_module_name("foo.bar", qconfig_opt)
.set_module_name_object_type_order()
)
# priority (in increasing order): global, object_type, module_name_regex, module_name
# qconfig == None means fusion and quantization should be skipped for anything
# matching the rule (unless a higher priority match is found)
與 qconfig
相關的工具函式可以在 qconfig 檔案中找到,而與 QConfigMapping
相關的工具函式則可以在 qconfig_mapping <https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/qconfig_mapping_utils.py> 中找到。
# The old 'fbgemm' is still available but 'x86' is the recommended default.
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
5. 為訓練後靜態量化準備模型¶
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
prepare_fx 將 BatchNorm 模組摺疊到先前的 Conv2d 模組中,並在模型中的適當位置插入觀察者 (observers)。
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
print(prepared_model.graph)
6. 校準 (Calibration)¶
校準函式會在觀察者插入模型後執行。校準的目的是執行一些具有代表性的工作負載範例 (例如,訓練資料集的一個樣本),以便模型中的觀察者能夠觀察 Tensors 的統計資訊,並且我們稍後可以使用此資訊來計算量化參數。
def calibrate(model, data_loader):
model.eval()
with torch.no_grad():
for image, target in data_loader:
model(image)
calibrate(prepared_model, data_loader_test) # run calibration on sample data
7. 將模型轉換為量化模型¶
convert_fx
接收一個經過校準的模型,並產生一個量化模型。
quantized_model = convert_fx(prepared_model)
print(quantized_model)
8. 評估 (Evaluation)¶
我們現在可以印出量化模型的大小和準確度。
print("Size of model before quantization")
print_size_of_model(float_model)
print("Size of model after quantization")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("[before serilaization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
fx_graph_mode_model_file_path = saved_model_dir + "resnet18_fx_graph_mode_quantized.pth"
# this does not run due to some erros loading convrelu module:
# ModuleAttributeError: 'ConvReLU2d' object has no attribute '_modules'
# save the whole model directly
# torch.save(quantized_model, fx_graph_mode_model_file_path)
# loaded_quantized_model = torch.load(fx_graph_mode_model_file_path, weights_only=False)
# save with state_dict
# torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path)
# import copy
# model_to_quantize = copy.deepcopy(float_model)
# prepared_model = prepare_fx(model_to_quantize, {"": qconfig})
# loaded_quantized_model = convert_fx(prepared_model)
# loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path), weights_only=True)
# save with script
torch.jit.save(torch.jit.script(quantized_model), fx_graph_mode_model_file_path)
loaded_quantized_model = torch.jit.load(fx_graph_mode_model_file_path)
top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test)
print("[after serialization/deserialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
如果您想要獲得更好的準確度或效能,請嘗試更改 qconfig_mapping。我們計劃在 Numerical Suite 中新增對 graph 模式的支援,以便您可以輕鬆確定模型中不同模組對於量化的敏感度。有關更多資訊,請參閱 PyTorch Numerical Suite 教學
9. 偵錯量化模型¶
我們還可以印出量化和非量化卷積運算的權重,以查看差異,我們將首先顯式呼叫 fuse 以融合模型中的卷積和批次正規化:請注意,fuse_fx
僅在 eval 模式下有效。
fused = fuse_fx(float_model)
conv1_weight_after_fuse = fused.conv1[0].weight[0]
conv1_weight_after_quant = quantized_model.conv1.weight().dequantize()[0]
print(torch.max(abs(conv1_weight_after_fuse - conv1_weight_after_quant)))
10. 與基準浮點模型和 Eager 模式量化的比較¶
scripted_float_model_file = "resnet18_scripted.pth"
print("Size of baseline model")
print_size_of_model(float_model)
top1, top5 = evaluate(float_model, criterion, data_loader_test)
print("Baseline Float Model Evaluation accuracy: %2.2f, %2.2f"%(top1.avg, top5.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)
在本節中,我們將使用 FX graph 模式量化的模型與在 eager 模式下量化的模型進行比較。FX graph 模式和 eager 模式產生非常相似的量化模型,因此預期準確度和加速效果也相似。
print("Size of Fx graph mode quantized model")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("FX graph mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
from torchvision.models.quantization.resnet import resnet18
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
print("Size of eager mode quantized model")
eager_quantized_model = torch.jit.script(eager_quantized_model)
print_size_of_model(eager_quantized_model)
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
print("eager mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
eager_mode_model_file = "resnet18_eager_mode_quantized.pth"
torch.jit.save(eager_quantized_model, saved_model_dir + eager_mode_model_file)
我們可以發現 FX graph 模式和 eager 模式量化模型的模型大小和準確度非常相似。
在 AIBench 中執行模型 (使用單執行緒) 會得到以下結果
Scripted Float Model:
Self CPU time total: 192.48ms
Scripted Eager Mode Quantized Model:
Self CPU time total: 50.76ms
Scripted FX Graph Mode Quantized Model:
Self CPU time total: 50.63ms
正如我們所看到的,對於 resnet18,FX graph 模式和 eager 模式量化模型都比浮點模型獲得了相似的加速,大約比浮點模型快 2-4 倍。但是,與浮點模型相比,實際的加速效果可能會因模型、裝置、建置、輸入批次大小、執行緒等因素而異。