Torch Export to StableHLO¶
本文件說明如何使用 torch export + torch xla 匯出為 StableHLO 格式。
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo
import torch_xla.core.xla_model as xm
import torchvision
import torch
xla_device = xm.xla_device()
resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)
stablehlo_program = exported_program_to_stablehlo(exported)
# Now stablehlo_program is a callable backed by stablehlo IR.
# we can see it's stablehlo code with
# here 'forward' is the name of function. Currently we only support
# one entry point per program, but in the future we will support
# multiple entry points in a program.
print(stablehlo_program.get_stablehlo_text('forward'))
# we can also print out the bytecode
print(stablehlo_program.get_stablehlo_bytecode('forward'))
# we can also run the module, to run the stablehlo module, we need to move
# our tensors to XLA device.
sample_input_xla = tuple(s.to(xla_device) for s in sample_input)
output2 = stablehlo_program(*sample_input_xla)
print(torch.allclose(output, output2.cpu(), atol=1e-5))
將 StableHLO 位元組碼儲存到磁碟¶
現在可以使用以下方式將 stablehlo 儲存到磁碟
stablehlo_program.save('/tmp/stablehlo_dir')
路徑應為空目錄的路徑。如果路徑不存在,則會建立路徑。此目錄可以再次載入為另一個 stablehlo_program
from torch_xla.stablehlo import StableHLOGraphModule
stablehlo_program2 = StableHLOGraphModule.load('/tmp/stablehlo_dir')
output3 = stablehlo_program2(*sample_input_xla)
轉換儲存的 StableHLO 以進行服務¶
StableHLO 是一種開放格式,並支援在 tensorflow.serving 模型伺服器中進行服務。但是,在將其提供給 tf.serving 之前,我們需要先將產生的 StableHLO 位元組碼包裝成 tf.saved_model
格式。
為此,首先請確保您在目前的 python 環境中安裝了最新的 tensorflow
,如果沒有,請使用以下方式安裝
pip install tf-nightly
現在,您可以執行轉換器 (在 torch/xla 安裝中提供)
stablehlo-to-saved-model /tmp/stablehlo_dir /tmp/resnet_tf/1
之後,您可以使用 tf serving 二進位檔在新產生的 tf.saved_model
上執行模型伺服器。
docker pull tensorflow/serving
docker run -p 8500:8500 \
--mount type=bind,source=/tmp/resnet_tf,target=/models/resnet_tf \
-e MODEL_NAME=resnet_tf -t tensorflow/serving &
您也可以直接使用 tf.serving
二進位檔,而無需 docker。如需更多詳細資訊,請參閱 tf serving 指南。
常見包裝函式¶
我想直接儲存 tf.saved_model 格式,而無需執行個別命令。¶
您可以使用此輔助函式來完成此操作
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
save_torch_module_as_tf_saved_model(
resnet18, # original pytorch torch.nn.Module
sample_inputs, # sample inputs used to trace
'/tmp/resnet_tf' # directory for tf.saved_model
)
其他常見包裝函式¶
def save_as_stablehlo(exported_model: 'ExportedProgram',
stablehlo_dir: os.PathLike,
options: Optional[StableHLOExportOptions] = None):
save_as_stablehlo
(也別名為 torch_xla.save_as_stablehlo
) 接受 ExportedProgram 並將 StableHLO 儲存到磁碟。亦即,與 exported_program_to_stablehlo(…).save(…) 相同
def save_torch_model_as_stablehlo(
torchmodel: torch.nn.Module,
args: Tuple[Any],
path: os.PathLike,
options: Optional[StableHLOExportOptions] = None) -> None:
"""Convert a torch model to a callable backed by StableHLO.
接受 torch.nn.Module
並將 StableHLO 儲存到磁碟。亦即,與 torch.export.export,然後接著 save_as_stablehlo 相同
save_as_stablehlo
產生的檔案。¶
在上述範例的 /tmp/stablehlo_dir
內,您會找到 3 個目錄:data
、constants
、functions
。data 和 constants 都將包含程式使用的張量,並以 numpy.ndarray
格式儲存,使用 numpy.save
。
functions 目錄將包含 StableHLO 位元組碼 (在此命名為 forward.bytecode
)、人類可讀取的 StableHLO 程式碼 (MLIR 格式) forward.mlir
,以及一個 JSON 檔案,指定哪些權重和原始使用者的輸入會變成此 StableHLO 函式的哪些位置引數;以及每個引數的 dtypes 和形狀。
範例
$ find /tmp/stablehlo_dir
./functions
./functions/forward.mlir
./functions/forward.bytecode
./functions/forward.meta
./constants
./constants/3
./constants/1
./constants/0
./constants/2
./data
./data/L__fn___layers_15_feed_forward_w2.weight
./data/L__fn___layers_13_feed_forward_w1.weight
./data/L__fn___layers_3_attention_wo.weight
./data/L__fn___layers_12_ffn_norm_weight
./data/L__fn___layers_25_attention_wo.weight
...
JSON 檔案是 torch_xla.stablehlo.StableHLOFunc
類別的序列化形式。
此格式目前也處於原型階段,且不保證向後相容性。未來的計畫是標準化一個主要框架 (PyTorch、JAX、TensorFlow) 可以同意的格式。
透過產生 stablehlo.composite
在 StableHLO 中保留高階 PyTorch 運算¶
高階 PyTorch 運算 (例如 F.scaled_dot_product_attention
) 將在 PyTorch -> StableHLO 降低期間分解為低階運算。在下游 ML 編譯器中擷取高階運算對於產生高效能、有效率的專用核心至關重要。雖然在 ML 編譯器中模式比對一堆低階運算可能具有挑戰性且容易出錯,但我們提供一種更穩健的方法,可在 StableHLO 程式中概述高階 PyTorch 運算 - 透過為高階 PyTorch 運算產生 stablehlo.composite。
透過 StableHLOCompositeBuilder
,使用者可以概述 torch.nn.Module
的 forward
函式中的任意區域。然後在匯出的 StableHLO 程式中,將會產生概述區域的複合運算。
注意: 由於非張量輸入到概述區域的值將硬式編碼在匯出的圖表中,如果希望從下游編譯器中擷取這些值,請將這些值儲存為複合屬性。
以下範例顯示實際使用案例 - 擷取 scaled_product_attention
import torch
import torch.nn.functional as F
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.q_proj = torch.nn.Linear(128, 128, bias=False)
self.k_proj = torch.nn.Linear(128, 128, bias=False)
self.v_proj = torch.nn.Linear(128, 128, bias=False)
# Initialize the StableHLOCompositeBuilder with the name of the composite op and its attributes
# Note: To capture the value of non-tensor inputs, please pass them as attributes to the builder
self.b = StableHLOCompositeBuilder("test.sdpa", {"scale": 0.25, "other_attr": "val"})
def forward(self, x):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q, k, v = self.b.mark_inputs(q, k, v)
attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
attn_out = self.b.mark_outputs(attn_out)
attn_out = attn_out + x
return attn_out
input_args = (torch.randn((10, 8, 128)), )
# torch.export to Exported Program
exported = torch.export.export(M(), input_args)
# Exported Program to StableHLO
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
stablehlo = stablehlo_gm.get_stablehlo_text()
print(stablehlo)
主要的 StableHLO 圖表如下所示
module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<10x8x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<10x8x128xf32> {
...
%10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl} : (tensor<10x8x128xf32>, tensor<10x8x128xf32>, tensor<10x8x128xf32>) -> tensor<10x8x128xf32>
%11 = stablehlo.add %10, %arg0 : tensor<10x8x128xf32>
return %11 : tensor<10x8x128xf32>
}
func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
// Actual implementation of the composite
...
return %11 : tensor<10x8x128xf32>
}
sdpa 運算會封裝為主圖表中的 stablehlo 複合呼叫。torch.nn.Module 中指定的名稱和屬性會傳播。
%10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}
sdpa 運算的參考 PyTorch 分解會擷取在 StableHLO 函式中
func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
// Actual implementation of the composite
...
return %11 : tensor<10x8x128xf32>
}