• 文件 >
  • Torch-TensorRT Dynamo 後端
捷徑

Torch-TensorRT Dynamo 後端

本指南介紹 Torch-TensorRT Dynamo 後端,它使用 TensorRT 以 Ahead-Of-Time 方式優化 PyTorch 模型。

使用 Dynamo 後端

PyTorch 2.1 引入了 torch.export API,可以將圖形從 PyTorch 程式匯出到 ExportedProgram 物件。Torch-TensorRT Dynamo 後端會編譯這些 ExportedProgram 物件,並使用 TensorRT 優化它們。以下是 Dynamo 後端的一個簡單用法

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()]
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) # Output is a torch.fx.GraphModule
trt_gm(*inputs)

備註

torch_tensorrt.dynamo.compile 是使用者與 Torch-TensorRT Dynamo 後端互動的主要 API。模型的輸入類型應為 ExportedProgram(理想情況下為 torch.export.exporttorch_tensorrt.dynamo.trace 的輸出(將在下節中討論)),而輸出類型為 torch.fx.GraphModule 物件。

可自訂的設定

使用者可以使用許多選項自訂使用 TensorRT 進行優化的設定。以下是一些常用的選項

  • inputs - 對於靜態形狀,這可以是 torch 張量或 torch_tensorrt.Input 物件的清單。對於動態形狀,這應該是 torch_tensorrt.Input 物件的清單。

  • enabled_precisions - TensorRT 建置器在優化過程中可以使用的精度集。

  • truncate_long_and_double - 將 long 和 double 值分別截斷為 int 和 float。

  • torch_executed_ops - 強制由 Torch 執行的運算子。

  • min_block_size - 作為 TensorRT 區段執行所需的最小連續運算子數。

可以在 這裡 找到選項的完整清單

備註

我們目前在 Dynamo 中不支援 INT 精度。目前在

我們的 Torchscript IR 中支援此功能。我們計劃在下一個版本中為 Dynamo 實作類似的支援。

內部機制

在內部,torch_tensorrt.dynamo.compile 會對圖形執行以下操作。

  • 降低層級 - 應用降低層級傳遞以新增/移除運算子,以進行最佳轉換。

  • 分割區 - 根據 min_block_sizetorch_executed_ops 欄位,將圖形分割為 PyTorch 和 TensorRT 區段。

  • 轉換 - 在此階段,PyTorch 運算子會轉換為 TensorRT 運算子。

  • 優化 - 轉換後,我們會建置 TensorRT 引擎,並將其嵌入 PyTorch 圖形中。

追蹤

torch_tensorrt.dynamo.trace 可用於追蹤 PyTorch 圖形並產生 ExportedProgram。這會在內部對運算子執行一些分解,以進行下游優化。然後,可以使用 torch_tensorrt.dynamo.compile API 使用 ExportedProgram。如果您的模型中有動態輸入形狀,則可以使用此 torch_tensorrt.dynamo.trace 匯出具有動態形狀的模型。或者,您也可以直接使用帶有約束的 torch.export

import torch
import torch_tensorrt

inputs = [torch_tensorrt.Input(min_shape=(1, 3, 224, 224),
                              opt_shape=(4, 3, 224, 224),
                              max_shape=(8, 3, 224, 224),
                              dtype=torch.float32)]
model = MyModel().eval()
exp_program = torch_tensorrt.dynamo.trace(model, inputs)

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學課程

取得針對初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源