• 文件 >
  • 使用 Torch-TensorRT 的動態形狀
捷徑

使用 Torch-TensorRT 的動態形狀

預設情況下,您可以使用不同的輸入形狀執行 PyTorch 模型,並且輸出形狀會被立即確定。然而,Torch-TensorRT 是一個 AOT 編譯器,它需要一些關於輸入形狀的先驗資訊才能編譯和最佳化模型。

使用 torch.export (AOT) 的動態形狀

在動態輸入形狀的情況下,我們必須提供 (min_shape, opt_shape, max_shape) 參數,以便可以針對此輸入形狀範圍最佳化模型。以下是靜態和動態形狀的使用範例。

注意:以下程式碼使用 Dynamo 前端。如果是 Torchscript 前端,請將 ir=dynamo 替換為 ir=ts,行為完全相同。

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
# Compile with static shapes
inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
# or compile with dynamic shapes
inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                              opt_shape=[4, 3, 224, 224],
                              max_shape=[8, 3, 224, 224],
                              dtype=torch.float32)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)

底層原理

當我們使用 torch_tensorrt.compile API 搭配 ir=dynamo (預設) 時,編譯過程分為兩個階段。

  • torch_tensorrt.dynamo.trace (使用 torch.export 以給定的輸入追蹤圖)

我們使用 torch.export.export() API 來追蹤並將 PyTorch 模組匯出成 torch.export.ExportedProgram。在動態形狀輸入的情況下,透過 torch_tensorrt.Input API 提供的 (min_shape, opt_shape, max_shape) 範圍會被用來建構 torch.export.Dim 物件,這些物件會用於匯出 API 的 dynamic_shapes 參數中。請查看 _tracer.py 檔案,以了解其底層運作方式。

  • torch_tensorrt.dynamo.compile (使用 TensorRT 編譯 torch.export.ExportedProgram 物件)

在轉換為 TensorRT 的過程中,圖的節點元數據中已包含動態形狀資訊,這些資訊將在引擎構建階段中使用。

自定義動態形狀約束

給定一個輸入 x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype),Torch-TensorRT 嘗試在 torch.export 追蹤期間,通過相應地構造帶有提供的動態維度的 torch.export.Dim 物件來自動設定約束。 有時,我們可能需要設定額外的約束,如果我們不指定這些約束,Torchdynamo 會出錯。 如果您必須為模型設定任何自定義約束 (通過使用 torch.export.Dim),我們建議您在使用 Torch-TensorRT 編譯之前先匯出您的程式。 請參閱此文檔,以匯出具有動態形狀的 Pytorch 模組。 這是一個簡單的例子,匯出了一個對動態維度有一些限制的 matmul 層。

import torch
import torch_tensorrt

class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key):
        attn_weight = torch.matmul(query, key.transpose(-1, -2))
        return attn_weight

model = MatMul().eval().cuda()
inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
seq_len = torch.export.Dim("seq_len", min=1, max=10)
dynamic_shapes=({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
exp_program = torch.export.export(model, tuple(inputs), dynamic_shapes=dynamic_shapes)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)
# Run inference
trt_gm(*inputs)

使用 torch.compile (JIT) 的動態形狀

torch_tensorrt.compile(model, inputs, ir="torch_compile") 返回一個 torch.compile box 函數,後端配置為 TensorRT。在 ir=torch_compile 的情況下,用戶可以使用 torch._dynamo.mark_dynamic API 提供輸入的動態形狀資訊(https://pytorch.dev.org.tw/docs/stable/torch.compiler_dynamic_shapes.html)來避免重新編譯 TensorRT 引擎。

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224), dtype=float32)
# This indicates the dimension 0 is dynamic and the range is [1, 8]
torch._dynamo.mark_dynamic(inputs, 0, min=1, max=8)
trt_gm = torch.compile(model, backend="tensorrt")
# Compilation happens when you call the model
trt_gm(inputs)

# No recompilation of TRT engines with modified batch size
inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32)
trt_gm(inputs_bs2)

文件

訪問 PyTorch 的全面開發者文檔

查看文檔

教學

獲取針對初學者和高級開發人員的深入教程

查看教程

資源

查找開發資源並獲得問題的解答

查看資源