• 文件 >
  • 使用 torch.compile 後端編譯 BERT
快捷鍵

使用 torch.compile 後端編譯 BERT

此互動式腳本旨在作為 BERT 模型上使用 torch.compile 的 Torch-TensorRT 工作流程範例。

匯入和模型定義

import torch
import torch_tensorrt
from transformers import BertModel
# Initialize model with float precision and sample inputs
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]

用於 torch_tensorrt.compile 的選用輸入引數

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.float}

# Whether to print verbose logs
debug = True

# Workspace size for TensorRT
workspace_size = 20 << 30

# Maximum number of TRT Engines
# (Lower value allows more graph segmentation)
min_block_size = 7

# Operations to Run in Torch, regardless of converter support
torch_executed_ops = {}

使用 torch.compile 編譯

# Define backend compilation keyword arguments
compilation_kwargs = {
    "enabled_precisions": enabled_precisions,
    "debug": debug,
    "workspace_size": workspace_size,
    "min_block_size": min_block_size,
    "torch_executed_ops": torch_executed_ops,
}

# Build and compile the model with torch.compile, using Torch-TensorRT backend
optimized_model = torch.compile(
    model,
    backend="torch_tensorrt",
    dynamic=False,
    options=compilation_kwargs,
)
optimized_model(*inputs)

或者,我們可以透過便利前端執行上述操作,如下所示:torch_tensorrt.compile(model, ir=”torch_compile”, inputs=inputs, **compilation_kwargs)

推論

# Does not cause recompilation (same batch size as input)
new_inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
new_outputs = optimized_model(*new_inputs)
# Does cause recompilation (new batch size)
new_inputs = [
    torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
]
new_outputs = optimized_model(*new_inputs)

清除

# Finally, we use Torch utilities to clean up the workspace
torch._dynamo.reset()

Cuda 驅動程式錯誤注意事項

偶爾,在使用 torch_tensorrt 進行 Dynamo 編譯後退出 Python 執行階段時,可能會遇到 Cuda 驅動程式錯誤。此問題與 https://github.com/NVIDIA/TensorRT/issues/2052 相關,可以透過將編譯/推論包裝在函式中並使用作用域呼叫來解決,如下所示

if __name__ == '__main__':
    compile_engine_and_infer()

腳本總執行時間: ( 0 分鐘 0.000 秒)

由 Sphinx-Gallery 產生圖庫

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源