• 文件 >
  • XLA 的量化操作 (實驗性功能)
快捷方式

XLA 的量化操作 (實驗性功能)

本文檔概述如何利用量化操作在 XLA 裝置上啟用量化。

XLA 量化操作為量化操作 (例如,分塊 int4 量化矩陣乘法) 提供高階抽象化。這些操作類似於 CUDA 生態系統中的量化 CUDA 核心 (範例),在 XLA 框架內提供類似的功能和效能優勢。

注意: 目前這被歸類為實驗性功能。其 API 細節將在下一個 (2.5) 版本中變更。

如何使用:

XLA 量化操作可以用作 torch op,或包裝 torch.optorch.nn.Module。這 2 個選項讓模型開發人員可以彈性選擇將 XLA 量化操作整合到其解決方案中的最佳方式。

torch opnn.Module 都與 torch.compile( backend='openxla') 相容。

在模型程式碼中呼叫 XLA 量化操作

使用者可以像呼叫其他常規 PyTorch 操作一樣呼叫 XLA 量化操作。這為將 XLA 量化操作整合到其應用程式中提供了最大的彈性。量化操作在 Eager 模式和 Dynamo 中都可運作,並支援常規 PyTorch CPU 張量和 XLA 張量。

注意 請查看量化操作的 docstring 以了解量化權重的佈局。

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_quantized_matmul

N_INPUT_FEATURES=10
N_OUTPUT_FEATURES=20
x = torch.randn((3, N_INPUT_FEATURES), dtype=torch.bfloat16)
w_int = torch.randint(-128, 127, (N_OUTPUT_FEATURES, N_INPUT_FEATURES), dtype=torch.int8)
scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16)

# Call with torch CPU tensor (For debugging purpose)
matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler)

device = xm.xla_device()
x_xla = x.to(device)
w_int_xla = w_int.to(device)
scaler_xla = scaler.to(device)

# Call with XLA Tensor to run on XLA device
matmul_output_xla = torch.ops.xla.quantized_matmul(x_xla, w_int_xla, scaler_xla)

# Use with torch.compile(backend='openxla')
def f(x, w, s):
  return torch.ops.xla.quantized_matmul(x, w, s)

f_dynamo = torch.compile(f, backend="openxla")
dynamo_out_xla = f_dynamo(x_xla, w_int_xla, scaler_xla)

將量化操作包裝到模型開發人員模型程式碼中的自訂 nn.Module 中是很常見的作法

class MyQLinearForXLABackend(torch.nn.Module):
  def __init__(self):
    self.weight = ...
    self.scaler = ...

  def load_weight(self, w, scaler):
    # Load quantized Linear weights
    # Customized way to preprocess the weights
    ...
    self.weight = processed_w
    self.scaler = processed_scaler


  def forward(self, x):
    # Do some random stuff with x
    ...
    matmul_output = torch.ops.xla.quantized_matmul(x, self.weight, self.scaler)
    # Do some random stuff with matmul_output
    ...

模組交換

或者,使用者也可以使用包裝 XLA 量化操作的 nn.Module,並在模型程式碼中進行模組交換

orig_model = MyModel()
# Quantize the model and get quantized weights
q_weights = quantize(orig_model)
# Process the quantized weight to the format that XLA quantized op expects.
q_weights_for_xla = process_for_xla(q_weights)

# Do module swap
q_linear = XlaQuantizedLinear(self.linear.in_features,
                              self.linear.out_features)
q_linear.load_quantized_weight(q_weights_for_xla)
orig_model.linear = q_linear

支援的量化操作:

矩陣乘法

權重 激活 資料類型 支援
逐通道 (sym/asym) W8A16
逐通道 (sym/asym) 不適用 W8A8
逐通道 逐 token W8A8
逐通道 逐 token W4A8
分塊 (sym/asym) 不適用 W8A16
分塊 (sym/asym) 不適用 W8A16
分塊 逐 token W8A8
分塊 逐 token W4A8

注意 W[X]A[Y] 指的是權重為 X 位元,激活為 Y 位元。如果 X/Y 為 4 或 8,則指的是 int4/8。16 代表 bfloat16 格式。

文件

存取 PyTorch 的全面開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源