捷徑

torch.onnx

概述

Open Neural Network eXchange (ONNX) 是一種用於表示機器學習模型的開放標準格式。torch.onnx 模組從原生 PyTorch torch.nn.Module 模型中擷取計算圖,並將其轉換為 ONNX 圖

匯出的模型可以被任何支援 ONNX 的 runtime 使用,包括 Microsoft 的 ONNX Runtime

您可以使用兩種 ONNX 匯出器 API,如下所示。 兩者都可以透過函式 torch.onnx.export() 呼叫。下一個範例展示如何匯出一個簡單的模型。

import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 128, 5)

    def forward(self, x):
        return torch.relu(self.conv1(x))

input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32)

model = MyModel()

torch.onnx.export(
    model,                  # model to export
    (input_tensor,),        # inputs of the model,
    "my_model.onnx",        # filename of the ONNX model
    input_names=["input"],  # Rename inputs for the ONNX model
    dynamo=True             # True or False to select the exporter to use
)

接下來的章節將介紹匯出器的兩個版本。

基於 TorchDynamo 的 ONNX 匯出器

基於 TorchDynamo 的 ONNX 匯出器是 PyTorch 2.1 及更新版本最新的 (且為 Beta 版) 匯出器

TorchDynamo 引擎被用來嵌入 Python 的 frame evaluation API,並動態地將其位元組碼重寫成 FX Graph。產生的 FX Graph 會先經過潤飾,然後最終被轉換成 ONNX graph。

這種方法的主要優點是,FX graph 是透過位元組碼分析來捕獲的,這種分析保留了模型的動態特性,而不是使用傳統的靜態追蹤技術。

了解更多關於基於 TorchDynamo 的 ONNX Exporter

基於 TorchScript 的 ONNX Exporter

基於 TorchScript 的 ONNX exporter 自 PyTorch 1.2.0 起可用

TorchScript 被用來追蹤(透過 torch.jit.trace())模型並捕獲靜態計算圖。

因此,產生的圖有幾個限制

  • 它不記錄任何控制流,例如 if 語句或迴圈;

  • 不處理 trainingeval 模式之間的細微差異;

  • 不能真正處理動態輸入

為了嘗試支援靜態追蹤的限制,exporter 也支援 TorchScript scripting(透過 torch.jit.script()),它增加了對數據相關控制流的支援,例如。 然而,TorchScript 本身是 Python 語言的一個子集,因此並非所有 Python 中的功能都受支援,例如原地操作 (in-place operations)。

了解更多關於基於 TorchScript 的 ONNX Exporter

貢獻 / 開發

ONNX exporter 是一個社群專案,我們歡迎大家貢獻。 我們遵循 PyTorch 貢獻指南,但您可能也會有興趣閱讀我們的 開發 wiki

文件

存取 PyTorch 的全面開發者文件

檢視文件

教學

取得初學者和高級開發者的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源