torch.onnx¶
概述¶
Open Neural Network eXchange (ONNX) 是一種用於表示機器學習模型的開放標準格式。torch.onnx
模組從原生 PyTorch torch.nn.Module
模型中擷取計算圖,並將其轉換為 ONNX 圖。
匯出的模型可以被任何支援 ONNX 的 runtime 使用,包括 Microsoft 的 ONNX Runtime。
您可以使用兩種 ONNX 匯出器 API,如下所示。 兩者都可以透過函式 torch.onnx.export()
呼叫。下一個範例展示如何匯出一個簡單的模型。
import torch
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 128, 5)
def forward(self, x):
return torch.relu(self.conv1(x))
input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32)
model = MyModel()
torch.onnx.export(
model, # model to export
(input_tensor,), # inputs of the model,
"my_model.onnx", # filename of the ONNX model
input_names=["input"], # Rename inputs for the ONNX model
dynamo=True # True or False to select the exporter to use
)
接下來的章節將介紹匯出器的兩個版本。
基於 TorchDynamo 的 ONNX 匯出器¶
基於 TorchDynamo 的 ONNX 匯出器是 PyTorch 2.1 及更新版本最新的 (且為 Beta 版) 匯出器
TorchDynamo 引擎被用來嵌入 Python 的 frame evaluation API,並動態地將其位元組碼重寫成 FX Graph。產生的 FX Graph 會先經過潤飾,然後最終被轉換成 ONNX graph。
這種方法的主要優點是,FX graph 是透過位元組碼分析來捕獲的,這種分析保留了模型的動態特性,而不是使用傳統的靜態追蹤技術。
基於 TorchScript 的 ONNX Exporter¶
基於 TorchScript 的 ONNX exporter 自 PyTorch 1.2.0 起可用
TorchScript 被用來追蹤(透過 torch.jit.trace()
)模型並捕獲靜態計算圖。
因此,產生的圖有幾個限制
它不記錄任何控制流,例如 if 語句或迴圈;
不處理
training
和eval
模式之間的細微差異;不能真正處理動態輸入
為了嘗試支援靜態追蹤的限制,exporter 也支援 TorchScript scripting(透過 torch.jit.script()
),它增加了對數據相關控制流的支援,例如。 然而,TorchScript 本身是 Python 語言的一個子集,因此並非所有 Python 中的功能都受支援,例如原地操作 (in-place operations)。