• 文件 >
  • 直接從 PyTorch 使用 Torch-TensorRT TorchScript 前端
捷徑

直接從 PyTorch 使用 Torch-TensorRT TorchScript 前端

您現在將能夠直接從 PyTorch API 存取 TensorRT。使用此功能的流程與 在 Python 中使用 Torch-TensorRT 中描述的編譯工作流程非常相似。

首先將 torch_tensorrt 載入您的應用程式。

import torch
import torch_tensorrt

然後,給定一個 TorchScript 模組,您可以使用 torch._C._jit_to_backend("tensorrt", ...) API 使用 TensorRT 編譯它。

import torchvision.models as models

model = models.mobilenet_v2(pretrained=True)
script_model = torch.jit.script(model)

與 Torch-TensorRT 中的 compile API 不同,後者假設您嘗試編譯模組的 forward 函數,或是 convert_method_to_trt_engine,後者會將指定的函數轉換為 TensorRT 引擎,後端 API 將採用一個字典,該字典將要編譯的函數名稱映射到 Compilation Spec 物件,這些物件包裝了您提供給 compile 的相同類型的字典。有關編譯規格字典的更多資訊,請參閱 Torch-TensorRT TensorRTCompileSpec API 的文檔。

spec = {
    "forward": torch_tensorrt.ts.TensorRTCompileSpec(
        **{
            "inputs": [torch_tensorrt.Input([1, 3, 300, 300])],
            "enabled_precisions": {torch.float, torch.half},
            "refit": False,
            "debug": False,
            "device": {
                "device_type": torch_tensorrt.DeviceType.GPU,
                "gpu_id": 0,
                "dla_core": 0,
                "allow_gpu_fallback": True,
            },
            "capability": torch_tensorrt.EngineCapability.default,
            "num_avg_timing_iters": 1,
        }
    )
}

現在要使用 Torch-TensorRT 進行編譯,請將目標模組物件和規格字典提供給 torch._C._jit_to_backend("tensorrt", ...)

trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)

若要執行,請顯式呼叫要運行的函數方法 (與在標準 PyTorch 中直接呼叫模組本身的方式不同)

input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
print(trt_model.forward(input))

文件

取得關於 PyTorch 的全面開發者文檔

查看文檔

教學

取得針對初學者和高級開發者的深入教學課程

查看教學課程

資源

尋找開發資源並取得問題解答

查看資源