(prototype) FX 圖模式量化使用者指南¶
建立於:2021 年 8 月 20 日 | 最後更新:2023 年 12 月 12 日 | 最後驗證:2024 年 11 月 05 日
作者: Jerry Zhang
FX 圖模式量化需要可符號追蹤的模型。我們使用 FX 框架將可符號追蹤的 nn.Module 實例轉換為 IR,並在 IR 上操作以執行量化通道。請在 PyTorch 討論區發表您關於符號追蹤模型的問題
量化僅適用於模型中可符號追蹤的部分。資料相依的控制流程 - if 語句/for 迴圈等,使用符號追蹤的值 - 是一種常見但不支援的模式。如果您的模型不是端到端可符號追蹤的,您可以選擇以下幾個選項,僅在模型的一部分上啟用 FX 圖模式量化。您可以組合使用這些選項
- 不可追蹤的程式碼不需要量化
僅符號追蹤需要量化的程式碼
跳過符號追蹤不可追蹤的程式碼
- 不可追蹤的程式碼需要量化
重構您的程式碼,使其可符號追蹤
撰寫您自己的觀察和量化的子模組
如果不可符號追蹤的程式碼不需要量化,我們有以下兩個選項來執行 FX 圖模式量化
僅符號追蹤需要量化的程式碼¶
當整個模型不可符號追蹤,但我們想要量化的子模組可符號追蹤時,我們可以僅在該子模組上執行量化。
之前
class M(nn.Module):
def forward(self, x):
x = non_traceable_code_1(x)
x = traceable_code(x)
x = non_traceable_code_2(x)
return x
之後
class FP32Traceable(nn.Module):
def forward(self, x):
x = traceable_code(x)
return x
class M(nn.Module):
def __init__(self):
self.traceable_submodule = FP32Traceable(...)
def forward(self, x):
x = self.traceable_code_1(x)
# We'll only symbolic trace/quantize this submodule
x = self.traceable_submodule(x)
x = self.traceable_code_2(x)
return x
量化程式碼
qconfig_mapping = QConfigMapping().set_global(qconfig)
model_fp32.traceable_submodule = \
prepare_fx(model_fp32.traceable_submodule, qconfig_mapping, example_inputs)
請注意,如果需要保留原始模型,您需要在呼叫量化 API 之前自行複製它。
跳過符號追蹤不可追蹤的程式碼¶
當模組中有一些不可追蹤的程式碼,並且這部分程式碼不需要量化時,我們可以將這部分程式碼分解到一個子模組中,並跳過符號追蹤該子模組。
之前
class M(nn.Module):
def forward(self, x):
x = self.traceable_code_1(x)
x = non_traceable_code(x)
x = self.traceable_code_2(x)
return x
之後,不可追蹤的部分被移動到一個模組並標記為葉節點
class FP32NonTraceable(nn.Module):
def forward(self, x):
x = non_traceable_code(x)
return x
class M(nn.Module):
def __init__(self):
...
self.non_traceable_submodule = FP32NonTraceable(...)
def forward(self, x):
x = self.traceable_code_1(x)
# we will configure the quantization call to not trace through
# this submodule
x = self.non_traceable_submodule(x)
x = self.traceable_code_2(x)
return x
量化程式碼
qconfig_mapping = QConfigMapping.set_global(qconfig)
prepare_custom_config_dict = {
# option 1
"non_traceable_module_name": "non_traceable_submodule",
# option 2
"non_traceable_module_class": [MNonTraceable],
}
model_prepared = prepare_fx(
model_fp32,
qconfig_mapping,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict,
)
如果不可符號追蹤的程式碼需要量化,我們有以下兩個選項
重構您的程式碼,使其可符號追蹤¶
如果很容易重構程式碼並使程式碼可符號追蹤,我們可以重構程式碼並移除 Python 中不可追蹤的結構的使用。
有關符號追蹤支援的更多資訊,請參閱 此處。
之前
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
這不可符號追蹤,因為在 x.view(*new_x_shape) 中不支援解包,但是,很容易移除解包,因為 x.view 也支援列表輸入。
之後
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
這可以與其他方法結合,並且量化程式碼取決於模型。
撰寫您自己的觀察和量化的子模組¶
如果不可追蹤的程式碼無法重構為可符號追蹤,例如,它有一些無法消除的迴圈,例如 nn.LSTM,我們需要將不可追蹤的程式碼分解到一個子模組中(我們在 fx 圖模式量化中稱之為 CustomModule),並定義子模組的觀察和量化版本(在後訓練靜態量化或用於靜態量化的量化感知訓練中),或定義量化版本(在後訓練動態和僅權重量化中)
之前
class M(nn.Module):
def forward(self, x):
x = traceable_code_1(x)
x = non_traceable_code(x)
x = traceable_code_1(x)
return x
之後
1. 將 non_traceable_code 分解為 FP32NonTraceable 不可追蹤的邏輯,封裝在模組中
class FP32NonTraceable:
...
2. 定義 FP32NonTraceable 的觀察版本
class ObservedNonTraceable:
@classmethod
def from_float(cls, ...):
...
3. 定義 FP32NonTraceable 的靜態量化版本和一個類別方法 "from_observed",以從 ObservedNonTraceable 轉換為 StaticQuantNonTraceable
class StaticQuantNonTraceable:
@classmethod
def from_observed(cls, ...):
...
# refactor parent class to call FP32NonTraceable
class M(nn.Module):
def __init__(self):
...
self.non_traceable_submodule = FP32NonTraceable(...)
def forward(self, x):
x = self.traceable_code_1(x)
# this part will be quantized manually
x = self.non_traceable_submodule(x)
x = self.traceable_code_1(x)
return x
量化程式碼
# post training static quantization or
# quantization aware training (that produces a statically quantized module)v
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
"static": {
FP32NonTraceable: ObservedNonTraceable,
}
},
}
model_prepared = prepare_fx(
model_fp32,
qconfig_mapping,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict)
校準/訓練(未顯示)
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
"static": {
ObservedNonTraceable: StaticQuantNonTraceable,
}
},
}
model_quantized = convert_fx(
model_prepared,
convert_custom_config_dict)
後訓練動態/僅權重量化在這兩種模式下,我們不需要觀察原始模型,因此我們只需要定義 thee 量化模型
class DynamicQuantNonTraceable: # or WeightOnlyQuantMNonTraceable
...
@classmethod
def from_observed(cls, ...):
...
prepare_custom_config_dict = {
"non_traceable_module_class": [
FP32NonTraceable
]
}
# The example is for post training quantization
model_fp32.eval()
model_prepared = prepare_fx(
model_fp32,
qconfig_mapping,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict)
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
"dynamic": {
FP32NonTraceable: DynamicQuantNonTraceable,
}
},
}
model_quantized = convert_fx(
model_prepared,
convert_custom_config_dict)
您還可以在 torch/test/quantization/test_quantize_fx.py
中的測試 test_custom_module_class
中找到自訂模組的範例。