• 文件 >
  • torchao 貢獻者指南
捷徑

torchao 貢獻者指南

目標

在這份文件中,我們將討論 (1). 不同最佳化技術在 torchao 中的結構 (2). 如何為 torchao 貢獻

注意:此文件目前主要關注推論,但我們計劃在未來擴展到涵蓋訓練技術。

torchao 堆疊概述

首先,我們想介紹 torchao 堆疊

Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
---------------------------------------------------------------------------------------------
        Quantized Tensors (derived dtypes): AffineQuantizedTensor, CodebookQuantizedTensor
---------------------------------------------------------------------------------------------
  Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
---------------------------------------------------------------------------------------------
            Basic dtypes: uint1-uint7, int1-int8, float3-float8

任何量化演算法都將使用上述堆疊中的某些組件,例如 int4_weight_only 量化使用:(1) 僅權重量化流程 (2) tinygemm bf16 激活 + int4 權重核心量化原始運算 (3) 具有 TensorCoreTiledLayoutAffineQuantizedTensor 張量子類別 (4) torch.uint4 dtype(目前以 quant_min/quant_max 模擬)

注意:我們還將在量化張量部分討論如何將稀疏性與量化結合

基本資料類型

dtype 是一個有點超載的術語,基本 dtype 是指在沒有任何額外元數據的情況下有意義的 dtype(例如,當人們調用 torch.empty(.., dtype) 時有意義),更多詳細資訊請查看:dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833

無論我們做什麼量化,最終我們都將使用一些低精度 dtype 來表示量化資料,我們旨在在 torchao 中支援的 dtype 有

  • torch.uint1torch.uint8 在 pytorch 2.3 及更高版本中可用

  • torch.int1torch.int8 在 pytorch 2.6 及更高版本中可用

  • torch.float3_e2_m0torch.float4_e2_m1torch.float4_e3_m0torch.float5_e2_m2torch.float5_e3_m1torch.float6_e2_m3torch.float6_e3_m2torch.float8_e4m3fntorch.float8_e5m2torch.float8_e4m3fnuztorch.float8_e5m2fnuz(float8 已添加到 torch 中,如果 float4 和 float6 變得流行,我們也計劃將它們添加到 torch 中)

注意:以上某些目前僅為原型。 當它們變得流行並獲得硬體支援時,我們將考慮將它們添加到 pytorch 核心中。

目前支援

在實際實作方面,有兩個部分:1). 在 PyTorch 中,我們需要將 dtype 添加到 torch.dtype,例如 torch.uint2,範例:pytorch/pytorch#117208,但這些只是佔位符,以便我們可以使用 torch.uint2。 2). 在 PyTorch 之外(例如在 torchao 中),我們使用張量子類別實作這些 dtype 的張量運算,也需要一個標準的封裝格式。

在 PyTorch 中新增佔位符 dtype

如同 dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833 中提及的,在 PyTorch 中新增 dtype 的標準是其已被廣泛採用。對於上述基本 dtype,PyTorch 支援的有:

  • torch.uint1torch.uint8, torch.int1torch.int8, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz

對於其他類型,我們計劃等到有更多證據表明其已被廣泛採用且具備硬體支援後再考慮。

使用 Tensor 子類實現這些 dtype 的張量運算

對此,要求是我們確定一個「標準」封裝格式,並希望這個格式易於高效實作,但對於 uintx 和 floatx,我們尚未整合足夠的 kernel 來決定。因此,目前的封裝實作並非最終版本。在更多 uintx、intx 和 floatx kernel 被整合到 torchao 後,我們可以重新審視。

將 Tensor 子類整合到 PyTorch 原生工廠函數

之後,我們可以將工廠函數與 Tensor 子類連接,例如:torch.empty(..., dtype=torch.int4, ...) 可以創建一個 Int4Tensor Tensor 子類,其封裝格式在前一步中已確定。

量化原始運算 (Quantization Primitive Ops)

量化原始運算指的是用於在低精度量化張量和高精度張量之間轉換的運算子。我們主要有以下量化原始運算子: choose_qparams op:根據原始張量選擇量化參數,通常用於動態量化,例如仿射量化的 scale 和 zero_point; quantize op:根據量化參數,將原始高精度張量量化為前一節中提到的低精度張量; dequantize op:根據量化參數,將低精度張量反量化為高精度張量。

可能存在上述運算子的變體,以適應特定的使用情況。例如,對於靜態量化,我們可能有 choose_qparams_affine_with_min_max,它將根據從觀察過程中獲得的最小值/最大值來選擇量化參數。

高效的 Kernel

我們還將擁有與低精度張量配合使用的高效 kernel,例如

_weight_int4pack_mm,tinygemm int4 kernel (bf16 activation + int4 weight); int_matmul,接受兩個 int8 張量並輸出 int32 張量; int_scaled_matmul,執行 matmul 並將 scale 應用於結果。

注意:我們也可以依賴 torch.compile 來生成 kernel(透過 triton)。例如,目前的 int8 權重量化 kernel 僅依賴 torch.compile 來獲得加速。在這種情況下,沒有與量化類型相對應的特定「高效 kernel」。

量化張量(衍生 dtype)

在基本 dtype、量化原始運算和高效 kernel 的基礎上,我們可以將所有東西組合在一起,並透過繼承 torch.Tensor 來建立量化(低精度)張量。這個張量可以從高精度張量和一些參數來構建,這些參數可以配置使用者想要的特定量化方式。我們也可以稱其為衍生 dtype,因為它可以表示為基本 dtype 的張量和一些額外的 metadata(例如 scale)。

torchao 中的現有範例是 AffineQuantizedTensor,表示低精度張量是透過仿射映射從高精度張量量化的,即:low_precision_val = high_precision_val / scale + zero_point,其中 scale/zero_point 是量化參數,可以透過量化原始運算或透過某些最佳化程序來計算。仿射量化是一種非常常見的量化類型,因為當我們嘗試將較高精度值映射到較低精度值時,我們進行仿射變換 (high_preicsion_val / scale + zero_point)。另一種常見的量化類型,尤其是在較低位寬 (例如低於 4 位) 的情況下,是基於碼本/查閱表的量化。

Layout 和 TensorImpl

原生張量具有硬編碼的 layout 選擇清單,最常見的是 strided layout,它提供了 storage 的 strided、多維視圖,我們還有一些 sparse 和 mkldnn layout。

sparse COO tensor 為例,它具有 torch.sparse_coo layout,以及 SparseTensorImpl,後者會更改張量的儲存方式。

將張量封裝成不同格式的想法與 layout 的概念非常吻合,這就是我們希望重用它的原因。我們可以將 Layout 用於不同類型的封裝格式,將 TensorImpl 用於不同的儲存格式實作。可以在 Python 層級的張量子類中新增新的 TensorImpl,用於以封裝格式儲存張量,而無需修改 C++ PyTorch 核心程式碼。

例如,對於 _weight_int4pack_mm,我們需要將權重封裝成對 Tensor Core 友好的格式,我們稱其為 TensorCoreTiledLayout。我們為量化張量新增一個 tensor_impl 來儲存封裝 (或解封裝) 的權重,並使用 layout 來儲存與封裝相關的不同參數。

class AffineQuantizedTensor(...):
  # tensor_impl is also implemented with tensor subclass
  tensor_impl: torch.Tensor

  # to not conflict with existing layout property, we use `_layout`
  @property
  def _layout(self) -> Layout:
      return self.tensor_impl._layout

請注意,layout 不僅僅是自訂資料表示的抽象,它也用於 TensorImpl 如何與不同運算子互動。例如,相同的資料表示在運行相同的運算子時可以有不同的實作 (例如 transpose、quantized_linear),但運算子的語義應該保持不變。

量化 + Sparse Tensor 也可以透過 Layout 抽象來支援,例如 int4 weight only quantization + sparse。我們還提供了一些通用工具,可以幫助人們為量化張量新增不同的 layout,請查看下面的開發者指南以獲取程式碼範例。

量化演算法/流程 (Quantization Algorithms/Flows)

堆疊的最上層將會是最終的量化演算法和量化流程。傳統上,我們有權重量化 (weight only quantization)、動態量化 (dynamic quantization) 和靜態量化 (static quantization),但現在我們也看到越來越多類型的量化出現。

為了示範,假設經過先前的步驟,我們已經定義了 AffineQuantizedTensorto_affine_quantized 工廠函數。為了簡化,假設 to_affine_quantized 接受一個高精度浮點 Tensor 和一個 target_dtype (例如 torch.int8),並將其轉換為具有相應 dtype 的 AffineQuantizedTensor

注意:以下所有內容都是為了解釋概念,關於我們提供的 utils 和範例的更詳細介紹,可以在 Tensor Subclass Developer Guide 章節中找到。

權重量化

這是最簡單的量化形式,並且很容易將權重量化應用於模型,特別是當我們有了量化 Tensor (Quantized Tensor) 時。我們需要做的就是:

linear_module.weight = torch.nn.Parameter(to_affine_quantized_intx(linear_module.weight, …), requires_grad=False))

將上述內容應用於模型中的所有線性模組,我們將得到一個權重量化的模型。

動態 Activation 和權重量化

這以前被稱為「動態量化」,但它的意思是我們在運行時動態地量化 activation,並且也量化權重。與權重量化相比,主要問題是如何將量化應用於 activation。在 torchao 中,我們使用的常見模式是在量化的權重之上應用 to_linear_activation_quantized

quantized_weight = to_affine_quantized(linear_module.weight) activation_and_weight_quantized = to_linear_activation_quantized(quantized_weight) linear_module.weight = torch.nn.Parameter(activation_and_weight_quantized, requires_grad=False))

to_linear_activation_quantized 用於將量化應用於 activation,它接受一個 input_quant_func,該函數將量化 activation 和原始權重,並且在運行時,當它遇到 F.linear 運算時,它將把儲存的 input_qunat_func 應用於 activation,並重新分派到具有量化 activation 和權重的 F.linear

如果上述方法不起作用,使用者也可以進行模組交換,或使用 torch.fx.symbolic_trace() 取得一個可以修改的追蹤模組。

但使用 tensor 子類是首選,因為它更容易序列化/反序列化。如果我們使用 tensor 子類來支援動態量化,那麼我們可以直接加載量化的權重,而無需進一步準備模型。否則,我們需要在加載量化的權重之前,先進行模組交換或其他修改。

靜態 Activation 量化和權重量化

靜態量化意味著 activation 是靜態量化的,而不是在運行時動態量化的。在流程方面,靜態量化需要使用樣本資料進行校準,以便我們可以確定適當的量化參數。

在高層次上,靜態量化有三個步驟:(1)插入觀察者 (observers)(2)校準(3)量化模型

插入觀察者

在插入觀察者步驟中,我們需要將觀察者模組添加到運算子的輸入(和輸出)activation 和權重中,以收集 Tensor 的統計資訊。因此,我們需要解決兩個問題,如何定義觀察者模組?如何將觀察者模組添加到模型中。

如何定義觀察者模組

觀察者特定於:(1)量化類型(例如,仿射量化、基於查詢表的量化)(2)我們要追蹤的統計資訊類型,例如,最小最大值觀察者 (min max observer)、移動平均觀察者 (moving average observer)。

通常,觀察者模組應定義 forwardcalculate_qparams

對於仿射量化,我們定義了 AffineQuantizedMinMaxObserver,它根據仿射量化的粒度記錄 min_val/max_val,並且還定義了如何根據記錄的統計資訊來 calculate_qparams。

如何將觀察者模組添加到模型中
  1. 使用 Tensor 子類。如果您唯一感興趣量化的運算符是 linear,則可以使用 linear activation weight observer。我們還有一個對應的 insert_observer_ API,它可以處理修改 linear 的權重。

  2. 模組交換。或者,您也可以定義和 ObservedLinear 模組(或其他模組類型),並將未觀察到的模組與觀察到的模組交換。

校準

校準步驟通常很簡單,通常我們只需要通過校準資料集來運行模型即可。對於更複雜的校準(例如,我們記錄所有輸入並根據所有輸入進行優化),我們將在下一節中介紹其中一些。

量化

我們可以重複使用 quantize_ API,但提供一個不同的 apply_tensor_subclass 函數,該函數將觀察到的線性模組轉換為具有量化權重和靜態量化輸入 activation 的線性模組,這可以以與動態量化相同的方式完成(使用 to_linear_activation_quantized),請參閱範例

或者,使用者也可以進行模組交換

其他量化流程

對於不適合上述任何一種的其他量化流程/演算法,我們也打算為常見模式提供範例。例如,類似 GPTQ 的量化流程,已被 Autoround 採用,它使用 MultiTensor 和模組鉤子來優化模組。

如果您正在開發一種新的量化演算法/流程,並且不確定如何在 PyTorch 原生方式中實現它,請隨時開啟一個 issue 來描述您的演算法如何運作,我們可以幫助您提供有關實作細節的建議。

訓練

上述流程主要著重於推論,但低位元 dtype 張量也能用於訓練。

感知量化訓練

待辦事項(TODO)

低位元優化器

目前我們有一些低位元優化器的原型:main/torchao/prototype/low_bit_optim,它實現了一種特定類型的 4 位元、8 位元和 float8,並且可以與 FSDP 組合(使用查找表量化)。

量化訓練

與低位元優化器類似,我們在 main/torchao/prototype/quantized_training 中有量化訓練原型,並且我們也可以擴展 AffineQuantizedTensor 以支援訓練,初步啟用正在進行中,但需要大量的後續工作,包括使其適用於不同的核心等。

您還可以查看 量化訓練的教學,其中討論了如何使 dtype 張量子類可訓練。

案例研究:int4 權重純量化如何在 torchao 中運作?

為了將所有內容連接在一起,以下是關於 int4 權重純量化如何在 torchao 中實現的更詳細的逐步說明。

高層次摘要

::
量化流程:quantize_(model, int4_weight_only())
  • 發生了什麼:linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight), requires_grad=False)

  • 量化基本運算:呼叫 choose_qparams 和 quantize_affine 來量化張量

  • 量化的張量將會是 AffineQuantizedTensor,一種具有衍生 dtype(例如,帶有縮放比例和零點的 int4)的量化張量

  • 封裝運算 _convert_weight_to_int4pack 用於封裝量化的權重以實現高效執行

模型執行期間:model(input)
  • 在輸入和封裝的權重上呼叫 torch.ops.aten._weight_int4pack_mm

量化期間

首先,我們從 API 呼叫開始:quantize_(model, int4_weight_only()) 它的作用是將模型中 nn.Linear 模組的權重轉換為 int4 量化張量(AffineQuantizedTensor,它是 int4 dtype,非對稱,每個群組量化),使用 tinygemm 核心的佈局:tensor_core_tiled 佈局。

  • quantize_:模型級別的 API,通過應用來自用戶的轉換函數(第二個參數)來量化線性的權重

  • int4_weight_only:返回一個函數的函數,該函數將線性的權重轉換為 int4 權重純量化的權重 * 呼叫量化基本運算,如 choose_qparams_affine 和 quantize_affine 來量化張量

  • TensorCoreTiledLayout:tensor core tiled 佈局類型,存儲封裝格式的參數

  • TensorCoreTiledAQTTensorImpl:tensor core tiled TensorImpl,存儲封裝的權重以實現高效的 int4 權重純量化核心(tinygemm 核心)

模型執行期間

當我們運行量化模型 model(inputs) 時,我們將通過 nn.Linear 中的函數式線性運算符運行:

return F.linear(input, weight, bias)

其中 input 是一個 bfloat16 張量,weight 是一個 int4 AffineQuantizedTensor,它調用到 AffineQuantizedTensor 子類的 __torch_function__ 中,最終會調用 F.linear 的實現,當其中一個輸入是 AffineQuantizedTensor 時,因此它調用:

return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

_quantized_linear_op 遍歷 _AQT_QLINEAR_DISPATCH_TABLE 並檢查每個分派條件,如果分派條件通過,它將使用 input/weight/bias 調用實現。請查看 此文檔 以獲取 dispatch_conditionimpl 的解釋。

int4 權重純量化 dispatch_condition 檢查輸入是否為 bfloat16 張量,權重是否為 uint4 AffineQuantizedTensor wint4 權重純量化 核心實現 採用 bfloat16 輸入張量和 int4 AffineQuantizedTensor,並使用輸入張量和存儲在 weight_tensor.tensor_impl 中的封裝權重呼叫 torch.ops.aten._weight_int4pack_mm

儲存/載入期間

由於 AffineQuantizedTensor 權重仍然是一個 torch.Tensor,因此儲存/載入的工作方式與原始高精度浮點模型相同。 更多詳細資訊,請參閱序列化文檔

張量子類開發者指南

在前一節中,我們介紹了高階概述以及所有元件如何連接在一起。本節將重點介紹 Tensor 子類別,這是我們賴以提供彈性的主要擴充點,以便支援使用低精度 Tensor 進行推論、訓練和微調,並與 torch.compile、autograd 以及這些情境中的分散式 primitives 組合。

先決條件

一些外部可用的 Tensor 子類別資源

為何選擇 Tensor 子類別?

人們可以使用多種方法來實作量化技術或新的 dtype。我們建議使用基於 Tensor 子類別方法的主要動機有三點:(1). 將量化建模為 dtype 轉換是很自然的,因此使用 Tensor 子類別實作它意味著我們不會引入新的概念,而是重用 PyTorch 核心中已存在的現有概念,如 dtype、layout (2). 由於 Tensor 子類別在 torch function 或 aten ops 層級攔截計算,只要使用相同的功能/運算符,我們就能夠量化模型。這允許使用原生模組變體的模型(例如,稍微修改過的 nn.Linear 版本)仍然與量化相容。(3). Tensor 子類別也是 sparsity 和 distributed 等其他技術所採用的方法,因此使用 Tensor 子類別實作量化或 dtype 轉換將使其更容易與這些技術組合。

新 DType 的範例程式碼

請隨時從 教學 開始,這是一個端到端的工作範例,它結合了我們討論的所有內容,然後返回此文件以取得澄清和文件。

基本結構

Tensor 子類別需要定義一些基本方法:__new____init____tensor_flatten____tensor_unflatten__,以及 torch functions 的 dispatch functions __torch_function__ 和 aten ops 的 __torch_dispatch__

以下是一個基本結構的範例:

# 查閱 https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L437 中的文件 from torchao.utils import TorchAOBaseTensor

class MyDTypeLayout(TorchAOBaseTensor)

# 查看教學程式碼以取得詳細資訊 pass

class MyDtypeTensor(TorchAOBaseTensor)

“””我們需要定義 __new__ 以建構新的 Tensor 子類別實例,並定義 __init__ 以初始化實例。此處對引數列表的外觀沒有要求,唯一的要求是 __new__ 必須使用 torch.Tensor._make_wrapper_subclass(cls, shape, …) 呼叫來傳回 Tensor 實例“”” @staticmethod def __new__(

cls, tensor_impl: MyDTypeLayout, shape: torch.Size, dtype: Optional[torch.dtype] = None,

):

… return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(

self, tensor_impl: MyDTypeLayout, shape: torch.Size, …

):

self.tensor_impl = tensor_impl

“””__tensor_flatten____tensor_unflatten__ 用於將 Tensor 拆解為原生 Tensors/屬性,並從拆解後的 Tensor 和屬性重建 Tensor 子類別實例,這些都是定義 Tensor 子類別以支援 torch.compile 所必需的“”” def __tensor_flatten__(self)

return [“tensor_impl”], [self.shape]

“””查看 https://github.com/pytorch/pytorch/blob/3bc2004f9123a32f381ef64202252d59109507f3/torch/utils/_python_dispatch.py#L289 以取得 outer_size 和 outer_stride 的文件“”” @classmethod def __tensor_unflatten__(

cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride

):
tensor_impl = tensor_data_dict[“tensor_impl”]

shape, = tensor_attributes return cls(

tensor_impl, shape if outer_size is None else outer_size,

)

“””classmethod,將浮點 Tensor (fp32/fp16/bf16) 轉換為目前 dtype“””

@classmethod
def from_float(

cls, input_float: torch.Tensor,

):

mapping_type = MappingType.SYMMETRIC block_size = input_float.shape dtype = torch.int16 scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) int_data = (input_float / scale).to(torch.int8) tensor_impl = MyDTypeLayout.from_plain(int_data, scale) return cls(tensor_impl, input_float.shape)

“””[Optional] 查看 量化 Tensors 區段下的 Layout/Packing 文件,以了解 layout_type 是什麼“”” @property def _layout(self) -> LayoutType

return self.tensor_impl._layout

“””我們可以使用兩個入口點來修改 PyTorch 運算子的行為:torch_function 和 torch_dispatch

__torch_function__: 每當在 Tensor 物件上呼叫 torch 層級函式時,都會呼叫此方法,例如:torch.nn.functional.linear、tensor.detach、tensor.reshape、tensor.t 等。

__torch_dispatch__: 當在 Tensor 物件上呼叫 aten 運算符時,將在 C++ 調度器中呼叫此方法,例如:aten.mm、aten.addmm、aten.detach.default、aten.t.default 等。您可以查看 https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L361-L389 以了解 __torch_function____torch_dispatch__ 正在做什麼,但是透過 TorchAoBaseTensor,使用者可以直接使用一些 helper functions(請參閱下一節)

運算符支援

有兩種運算符支援類型:torch function 和 aten ops。對於 torch functions(例如 torch.nn.functional.linear),我們需要在 Tensor 子類別中覆寫 __torch_function__ 回呼函數;對於 aten ops(例如 torch.ops.aten.mm),我們需要覆寫 __torch_dispatch__ 回呼函數。

對於新的 dtype,我們希望人們定義以下 decorator:

如果您的 dtype 類別是從 torchao.utils.TorchAoBaseTensor 繼承的,您可以執行

implements = my_dtype_tensor_cls.implements

我們可以透過以下方式實作運算符調度:

# torch.nn.functional.linear 的 torch_function 調度範例 def _quantized_linear_op(input_tensor, weight_tensor, bias)

if isinstance(input_tensor, MyDtypeTensor)

input_tensor = input_tensor.dequantize()

if isinstance(weight_tensor, MyDtypeTensor)

weight_tensor = weight_tensor.dequantize()

return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements(torch.nn.functional.linear) def _(*args, **kwargs)

input_tensor, weight_tensor, bias = (

args[0], args[1], args[2] if len(args) > 2 else None,

) # 這裡使用 try/except,以便在 _quantized_linear_op 中的任何調度路徑都未選取 input_tensor/weight_tensor 時,我們能有一個通用的回退方案。這讓我們能更容易理解 _quantized_linear_op 中的分支。 try

return _quantized_linear_op(input_tensor, weight_tensor, bias)

except NotImplementedError
if isinstance(input_tensor, MyDtypeTensor)

input_tensor = input_tensor.dequantize()

if isinstance(weight_tensor, MyDtypeTensor)

weight_tensor = weight_tensor.dequantize()

return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

# aten op 調度的範例,針對 aten.detach.default @implements(aten.detach.default) def _(func, *args, **kwargs)

# wrapper tensor __torch_dispatch__ 子類如果想要 # 使用 torch.compile,則應該使用 return_and_correct_aliasing。它確保子類正確地實現每個 op 的別名行為,# 這對於 AOTAutograd 的正確性至關重要。

# _apply_fn_to_data 僅將函數應用於 args[0] 中的 tensor 資料, args[0]my_dtype 的 tensor 子類 # return return_and_correct_aliasing(

func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)

)

我們需要覆寫哪些 ops? 這取決於我們嘗試量化的模型,通常覆寫的 ops 有: __torch_function__: torch.nn.functional.linear __torch_dispatch__: torch.ops.aten.addmm.default, torch.ops.aten.mm.default, torch.ops.aten.detach.default, torch.ops.aten.t.default

您也可以使用以下程式碼在 __torch_function____torch_dispatch__ 中找到可以覆寫的 ops,您可以從想要最佳化的模型開始,首先覆寫像 linear 這樣重要的 ops,然後逐漸擴大覆蓋範圍,直到測試執行並且您獲得預期的最佳化程式碼(有關更多詳細資訊,請參閱「最佳化運算子」部分):
class M(torch.nn.Module)
def __init__(self) -> None

super().__init__() self.linear = torch.nn.Linear(10, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor

return self.linear(x) + x

from torch.overrides import TorchFunctionMode class TorchFunctionLoggingMode(TorchFunctionMode)

def __torch_function__(cls, func, types, args=(), kwargs=None)
if kwargs is None

kwargs = {}

print(f”TORCH_FUNC={str(func)}”) return func(*args, **kwargs)

with TorchFunctionLoggingMode()

m(*example_inputs)

## 範例輸出 # TORCH_FUNC=<built-in function linear> # TORCH_FUNC=<method ‘add’ of ‘torch._C.TensorBase’ objects>

from torch.utils._python_dispatch import TorchDispatchMode class TorchDispatchLoggingMode(TorchDispatchMode)

def __torch_dispatch__(cls, func, types, args=(), kwargs=None)
if kwargs is None

kwargs = {}

print(f”ATEN_FUNC={str(func)}”) return func(*args, **kwargs)

with TorchDispatchLoggingMode()

m(*example_inputs)

## 範例輸出 # ATEN_FUNC=aten.t.default # ATEN_FUNC=aten.addmm.default # ATEN_FUNC=aten.add.Tensor

# 或者,針對 torch_dispatch (aten) ops 提供更完善的日誌記錄: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py

或者,您可以執行測試範例(例如,將您的量化模型與 tensor 並行處理、FSDP 等一起使用),並發現缺少的 ops 並添加它們,直到測試通過。

我們仍在努力製作一個表格,說明每個功能需要支援哪些運算子。

添加高效核心

自訂 Triton 核心

自訂 triton 核心可以在 torchao/kernel 中實作和註冊。

您可能還需要定義自己的 autotuner

手寫自訂核心

可以通過 torchao/csrc 實作 cpu/cuda/mps 的自訂核心 (實作),例如 int4 cuda,並且可以通過 torch.ops.my_custom_op 存取。

調度

對於調度到 cpu/cuda/mps 裝置的最佳化核心,我們可以在 __torch_function____torch_dispatch__ 中檢查調度條件,並調度到目標運算子。例如,bfloat16 啟動和 uint4 權重核心的條件可以在 這裡 找到。

特別是對於 AffineQuantizedTensor,我們也允許人們擴展量化的 linear 以使用新的高效核心,或者通過定義兩個函數來實作: dispatch_condition (定義調度到核心的條件) 和 impl (實際實作,它接收 activation, (quantized) weight, bias Tensor 並執行高效核心),兩者都將 input_tensor, weight_tensor, bias 作為參數,並且可以使用 register_aqt_quantized_linear_dispatch 註冊到 AffineQuantizedTensor 中的量化 linear 的調度。 這裡 是一個範例,展示了它是如何工作的。

Layout/TensorImpl

有時必須打包量化的權重才能產生最佳效能。這可以使用 layout 來抽象。 有關完整範例,請參閱 這裡

流程

在實作 tensor 子類之後,我們也可以將其封裝到 factory 函數中,例如:

# 從浮點 tensor 轉換為我的 dtype tensor 子類 to_my_dtype = MyDTypeTensor.from_float

對於模型層級的 API,人們可以重複使用 torchao.quantize_,它允許將 tensor 子類轉換應用於線性層的權重,並允許使用篩選函數來選擇要應用 tensor 子類轉換的模組。

請參閱「量化演算法/流程」章節,以取得基於工廠函數的權重專用/動態量化/靜態量化和其他類型模型層級 API 的範例。

使用 torch.compile 提升效能

請注意:對於 pytorch 2.4 及更早版本,我們需要使用以下內容:

from torchao.utils import unwrap_tensor_subclass m_unwrapped = unwrap_tensor_subclass(m)

為了與 torch.compile 相容,以達到效能最佳化,我們應首先使用 fullgraph=True 執行 torch.compile,並移除任何不必要的圖形中斷。 您可以新增 TORCH_LOGS="output_code",以便在執行指令碼時查看 inductor 產生的程式碼。例如:TORCH_LOGS="output_code" python example.py

model = torch.compile(model, mode=”max-autotune”, fullgraph=True)

序列化

請查看序列化文件以獲取更多詳細資訊。

注意

我們已與 huggingface transformer 整合,並支援透過 huggingface save_pretrained/push_to_hub/from_pretrained API 進行序列化/反序列化:https://huggingface.co/docs/transformers/main/en/quantization/torchao

注意

另一個範例可以在與 diffuser 的整合中找到:https://github.com/sayakpaul/diffusers-torchao/blob/main/inference/serialization_and_loading.md

其他功能支援

以上僅說明基本功能支援,我們還提供範例說明如何透過擴展MyDTypeTensor來新增對訓練、張量平行、FSDP 的支援,我們將在developer_api_guide資料夾中放入更多涵蓋以下使用案例的範例。

擴展 torchao 的一般指南

對於新的使用案例,例如訓練 dtype (如 fp4 訓練),可以從在原型資料夾 torchao/prototype 中新增 tensor 子類開始,但如果 AffineQuantizedTensor 大部分支援您想要做的事情,您也可以查看它,例如,為完全相同的仿射量化新增 int3 核心。 如果您對特定新使用案例應該做什麼有疑問,請隨時開啟一個 issue。

貢獻現有程式碼庫

Tensor 子類功能/可組合性測試

我們還在開發測試套件,以測試 tensor 子類的功能以及與不同系統 (如 torch.compile、DTensor 等) 的可組合性 (我們建議複製並貼上測試,並適應於測試您自己的 tensor 子類)。

核心微基準測試

在我們測試模型效能之前,我們也可以對具有不同輸入維度的單個線性運算符 (或其他計算密集型/記憶體密集型) 運算符進行一些微基準測試,以了解加速效果。 對於您要進行基準測試的特定核心,您可以建立一個基準測試檔案,例如 benchmarks/benchmark_aq.py,並使用對目標模型重要的不同形狀執行基準測試。 取得線性運算和其他運算的相關形狀的快速方法是使用 範例執行。

使用您感興趣的最佳化模型變更模型,並執行以下操作:

python tutorials/developer_api_guide/print_op_and_shapes.py

範例輸出:

TORCH_FUNC=<built-in function linear> (M, K, N): 10 10 10 TORCH_FUNC=<method ‘add’ of ‘torch._C.TensorBase’ objects> args[0] shape: torch.Size([10, 10])

所有線性形狀 (M, K, N):[(10, 10, 10)]

所有線性形狀的輸出都可以複製並貼到 benchmarks/benchmark_your_kernel.py 下的微基準測試指令碼程式碼中,以進行基準測試。

對於基準測試輔助函數,目前我們有 12,請隨時使用其中一個,但我們未來可能會保留一個。

模型基準測試和評估

實施量化流程後,您可以在已經修改為對 torch.compile 友好的 llama (llama2/llama3) 或 sam 模型上執行基準測試和評估,並與 torchao 中的現有技術進行比較。

注意:llama 模型 (llama2/llama3) 是我們記憶體密集型模型的代表模型,而 sam 是我們計算密集型模型的代表模型。

請查看每個指令碼的 --help 選項,以了解支援的選項,例如,您可以使用 --profile=profile_path 取得執行的 chrome 追蹤,以了解詳細的 chrome 追蹤

如果有任何您認為適合加入 torchao 模型基準測試/評估資料夾的新重要模型,請告訴我們。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您問題的解答

檢視資源