• 文件 >
  • Torch Export to StableHLO
快捷鍵

Torch Export to StableHLO

本文件說明如何使用 torch export + torch xla 匯出為 StableHLO 格式。

from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo
import torch_xla.core.xla_model as xm
import torchvision
import torch

xla_device = xm.xla_device()

resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)
stablehlo_program = exported_program_to_stablehlo(exported)

# Now stablehlo_program is a callable backed by stablehlo IR.

# we can see it's stablehlo code with
#   here 'forward' is the name of function. Currently we only support
#   one entry point per program, but in the future we will support
#   multiple entry points in a program.
print(stablehlo_program.get_stablehlo_text('forward'))

# we can also print out the bytecode
print(stablehlo_program.get_stablehlo_bytecode('forward'))

# we can also run the module, to run the stablehlo module, we need to move
# our tensors to XLA device.
sample_input_xla = tuple(s.to(xla_device) for s in sample_input)

output2 = stablehlo_program(*sample_input_xla)
print(torch.allclose(output, output2.cpu(), atol=1e-5))

將 StableHLO 位元組碼儲存到磁碟

現在可以使用以下方式將 stablehlo 儲存到磁碟

stablehlo_program.save('/tmp/stablehlo_dir')

路徑應為空目錄的路徑。如果路徑不存在,則會建立路徑。此目錄可以再次載入為另一個 stablehlo_program

from torch_xla.stablehlo import StableHLOGraphModule
stablehlo_program2 = StableHLOGraphModule.load('/tmp/stablehlo_dir')
output3 = stablehlo_program2(*sample_input_xla)

轉換儲存的 StableHLO 以進行服務

StableHLO 是一種開放格式,並支援在 tensorflow.serving 模型伺服器中進行服務。但是,在將其提供給 tf.serving 之前,我們需要先將產生的 StableHLO 位元組碼包裝成 tf.saved_model 格式。

為此,首先請確保您在目前的 python 環境中安裝了最新的 tensorflow,如果沒有,請使用以下方式安裝

pip install tf-nightly

現在,您可以執行轉換器 (在 torch/xla 安裝中提供)

stablehlo-to-saved-model /tmp/stablehlo_dir /tmp/resnet_tf/1

之後,您可以使用 tf serving 二進位檔在新產生的 tf.saved_model 上執行模型伺服器。

docker pull tensorflow/serving
docker run -p 8500:8500 \
--mount type=bind,source=/tmp/resnet_tf,target=/models/resnet_tf \
-e MODEL_NAME=resnet_tf -t tensorflow/serving &

您也可以直接使用 tf.serving 二進位檔,而無需 docker。如需更多詳細資訊,請參閱 tf serving 指南

常見包裝函式

我想直接儲存 tf.saved_model 格式,而無需執行個別命令。

您可以使用此輔助函式來完成此操作

from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model

save_torch_module_as_tf_saved_model(
    resnet18,  # original pytorch torch.nn.Module
    sample_inputs, # sample inputs used to trace
    '/tmp/resnet_tf'   # directory for tf.saved_model
)

其他常見包裝函式

def save_as_stablehlo(exported_model: 'ExportedProgram',
                      stablehlo_dir: os.PathLike,
                      options: Optional[StableHLOExportOptions] = None):

save_as_stablehlo (也別名為 torch_xla.save_as_stablehlo) 接受 ExportedProgram 並將 StableHLO 儲存到磁碟。亦即,與 exported_program_to_stablehlo(…).save(…) 相同

def save_torch_model_as_stablehlo(
    torchmodel: torch.nn.Module,
    args: Tuple[Any],
    path: os.PathLike,
    options: Optional[StableHLOExportOptions] = None) -> None:
    """Convert a torch model to a callable backed by StableHLO.

接受 torch.nn.Module 並將 StableHLO 儲存到磁碟。亦即,與 torch.export.export,然後接著 save_as_stablehlo 相同

save_as_stablehlo 產生的檔案。

在上述範例的 /tmp/stablehlo_dir 內,您會找到 3 個目錄:dataconstantsfunctions。data 和 constants 都將包含程式使用的張量,並以 numpy.ndarray 格式儲存,使用 numpy.save

functions 目錄將包含 StableHLO 位元組碼 (在此命名為 forward.bytecode)、人類可讀取的 StableHLO 程式碼 (MLIR 格式) forward.mlir,以及一個 JSON 檔案,指定哪些權重和原始使用者的輸入會變成此 StableHLO 函式的哪些位置引數;以及每個引數的 dtypes 和形狀。

範例

$ find /tmp/stablehlo_dir
./functions
./functions/forward.mlir
./functions/forward.bytecode
./functions/forward.meta
./constants
./constants/3
./constants/1
./constants/0
./constants/2
./data
./data/L__fn___layers_15_feed_forward_w2.weight
./data/L__fn___layers_13_feed_forward_w1.weight
./data/L__fn___layers_3_attention_wo.weight
./data/L__fn___layers_12_ffn_norm_weight
./data/L__fn___layers_25_attention_wo.weight
...

JSON 檔案是 torch_xla.stablehlo.StableHLOFunc 類別的序列化形式。

此格式目前也處於原型階段,且不保證向後相容性。未來的計畫是標準化一個主要框架 (PyTorch、JAX、TensorFlow) 可以同意的格式。

透過產生 stablehlo.composite 在 StableHLO 中保留高階 PyTorch 運算

高階 PyTorch 運算 (例如 F.scaled_dot_product_attention) 將在 PyTorch -> StableHLO 降低期間分解為低階運算。在下游 ML 編譯器中擷取高階運算對於產生高效能、有效率的專用核心至關重要。雖然在 ML 編譯器中模式比對一堆低階運算可能具有挑戰性且容易出錯,但我們提供一種更穩健的方法,可在 StableHLO 程式中概述高階 PyTorch 運算 - 透過為高階 PyTorch 運算產生 stablehlo.composite

透過 StableHLOCompositeBuilder,使用者可以概述 torch.nn.Moduleforward 函式中的任意區域。然後在匯出的 StableHLO 程式中,將會產生概述區域的複合運算。

注意: 由於非張量輸入到概述區域的值將硬式編碼在匯出的圖表中,如果希望從下游編譯器中擷取這些值,請將這些值儲存為複合屬性。

以下範例顯示實際使用案例 - 擷取 scaled_product_attention

import torch
import torch.nn.functional as F
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder


class M(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(128, 128, bias=False)
        self.k_proj = torch.nn.Linear(128, 128, bias=False)
        self.v_proj = torch.nn.Linear(128, 128, bias=False)
        # Initialize the StableHLOCompositeBuilder with the name of the composite op and its attributes
        # Note: To capture the value of non-tensor inputs, please pass them as attributes to the builder
        self.b = StableHLOCompositeBuilder("test.sdpa", {"scale": 0.25, "other_attr": "val"})

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q, k, v = self.b.mark_inputs(q, k, v)
        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
        attn_out = self.b.mark_outputs(attn_out)
        attn_out = attn_out + x
        return attn_out

input_args = (torch.randn((10, 8, 128)), )
# torch.export to Exported Program
exported = torch.export.export(M(), input_args)
# Exported Program to StableHLO
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
stablehlo = stablehlo_gm.get_stablehlo_text()
print(stablehlo)

主要的 StableHLO 圖表如下所示

module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10x8x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<10x8x128xf32> {
    ...
    %10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl} : (tensor<10x8x128xf32>, tensor<10x8x128xf32>, tensor<10x8x128xf32>) -> tensor<10x8x128xf32>
    %11 = stablehlo.add %10, %arg0 : tensor<10x8x128xf32>
    return %11 : tensor<10x8x128xf32>
  }

  func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
    // Actual implementation of the composite
    ...
    return %11 : tensor<10x8x128xf32>
  }

sdpa 運算會封裝為主圖表中的 stablehlo 複合呼叫。torch.nn.Module 中指定的名稱和屬性會傳播。

%10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}

sdpa 運算的參考 PyTorch 分解會擷取在 StableHLO 函式中

func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
    // Actual implementation of the composite
    ...
    return %11 : tensor<10x8x128xf32>
  }

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源