• 文件 >
  • 儲存使用 Torch-TensorRT 編譯的模型
快捷方式

儲存使用 Torch-TensorRT 編譯的模型

可以使用 torch_tensorrt.save API 完成儲存使用 Torch-TensorRT 編譯的模型。

Dynamo IR

預設情況下,Torch-TensorRT 的 ir=dynamo 編譯的輸出類型是 torch.fx.GraphModule 物件。我們可以透過指定 output_format 標誌,以 TorchScript (torch.jit.ScriptModule) 或 ExportedProgram (torch.export.ExportedProgram) 格式儲存此物件。以下是 output_format 將接受的選項

  • exported_program:這是預設值。我們先對 graphmodule 執行轉換,然後使用 torch.export.save 儲存模組。

  • torchscript:我們透過 torch.jit.trace 追蹤 graphmodule,並透過 torch.jit.save 儲存它。

a) ExportedProgram

以下是一個使用範例

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_ep is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

# Later, you can load it and run inference
model = torch.export.load("trt.ep").module()
model(*inputs)

b) Torchscript

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

# Later, you can load it and run inference
model = torch.jit.load("trt.ts").cuda()
model(*inputs)

Torchscript IR

在 Torch-TensorRT 1.X 版本中,使用 Torchscript IR 是編譯並執行 Torch-TensorRT 推論的主要方法。對於 ir=ts,此行為在 2.X 版本中也保持不變。

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(*inputs)

載入模型

我們可以使用 PyTorch 的 torch.jit.loadtorch.export.load API 直接載入 torchscript 或 exported_program 模型。另外,我們提供了一個輕量級的封裝函式 torch_tensorrt.load(file_path),它可以載入上述任一種模型類型。

以下是一個使用範例

import torch
import torch_tensorrt

# file_path can be trt.ep or trt.ts file obtained via saving the model (refer to the above section)
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
model = torch_tensorrt.load(<file_path>).module()
model(*inputs)

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源