• 文件 >
  • 在 Python 中使用 Torch-TensorRT
捷徑

在 Python 中使用 Torch-TensorRT

與僅支援 TorchScript 編譯的 CLI 和 C++ API 相比,Torch-TensorRT Python API 支援許多獨特的用例。

Torch-TensorRT Python API 可以接受 torch.nn.Moduletorch.jit.ScriptModuletorch.fx.GraphModule 作為輸入。根據提供的內容,將選擇兩個前端(TorchScript 或 FX)中的一個來編譯模組。如果提供的模組類型受支援,使用者可以使用 compileir 旗標明確設定他們想要使用哪個前端。如果給定 torch.nn.Moduleir 旗標設定為 defaulttorchscript,則模組將透過 torch.jit.script 執行,以將輸入模組轉換為 TorchScript 模組。

若要使用 Torch-TensorRT 編譯您的輸入 torch.nn.Module,您只需要將模組和輸入提供給 Torch-TensorRT,您將會收到一個最佳化的 TorchScript 模組,以供執行或新增到另一個 PyTorch 模組中。輸入是一個 torch_tensorrt.Input 類別的清單,這些類別定義了輸入張量的形狀、資料類型和記憶體格式。或者,如果您的輸入是更複雜的資料類型,例如張量的元組或清單,您可以使用 input_signature 參數來指定基於集合的輸入,例如 (List[Tensor], Tuple[Tensor, Tensor])。請參閱下面的第二個範例以獲取範例。您也可以指定引擎的運算精度或目標裝置等設定。編譯後,您可以像儲存任何其他模組一樣儲存模組,以便在部署應用程式中載入。若要載入 TensorRT/TorchScript 模組,請確保您先匯入 torch_tensorrt

import torch_tensorrt

...

model = MyModel().eval()  # torch module needs to be in eval (not training) mode

inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 1, 16, 16],
        opt_shape=[1, 1, 32, 32],
        max_shape=[1, 1, 64, 64],
        dtype=torch.half,
    )
]
enabled_precisions = {torch.float, torch.half}  # Run with fp16

trt_ts_module = torch_tensorrt.compile(
    model, inputs=inputs, enabled_precisions=enabled_precisions
)

input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
# Sample using collection-based inputs via the input_signature argument
import torch_tensorrt

...

model = MyModel().eval()

# input_signature expects a tuple of individual input arguments to the module
# The module below, for example, would have a docstring of the form:
# def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor])
input_signature = (
    [torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)],
    (torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)),
)
enabled_precisions = {torch.float, torch.half}

trt_ts_module = torch_tensorrt.compile(
    model, input_signature=input_signature, enabled_precisions=enabled_precisions
)

input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
# Deployment application
import torch
import torch_tensorrt

trt_ts_module = torch.jit.load("trt_ts_module.ts")
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)

Torch-TensorRT Python API 也提供了 torch_tensorrt.ts.compile,它接受 TorchScript 模組作為輸入,以及 torch_tensorrt.fx.compile,它接受 FX GraphModule 作為輸入。

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得適用於初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源