torch_tensorrt.fx¶
函式¶
- torch_tensorrt.fx.compile(module: Module, input, min_acc_module_size: int = 10, max_batch_size: int = 2048, max_workspace_size=33554432, explicit_batch_dimension=False, lower_precision=LowerPrecision.FP16, verbose_log=False, timing_cache_prefix='', save_timing_cache=False, cuda_graph_batch_size=- 1, dynamic_batch=True, is_aten=False, use_experimental_fx_rt=False, correctness_atol=0.1, correctness_rtol=0.1) Module [source]¶
接收原始模組、輸入和 lowering 設定,執行 lowering 工作流程以將模組轉換為 lowered 模組,或稱為 TRTModule。
- 參數
module – 用於 lowering 的原始模組。
input – 模組的輸入。
max_batch_size – 最大批次大小 (必須 >= 1 才能設定,0 代表未設定)
min_acc_module_size – 加速子模組的最小節點數
max_workspace_size – 給予 TensorRT 的最大工作區大小。
explicit_batch_dimension – 若設為 True,則在 TensorRT 中使用明確的批次維度,否則使用隱含的批次維度。
lower_precision – 給予 TRTModule 的 lower_precision 設定。
verbose_log – 若設為 True,則啟用 TensorRT 的詳細記錄。
timing_cache_prefix – fx2trt 使用的 timing cache 檔名。
save_timing_cache – 若設為 True,則使用目前的 timing cache 資料更新 timing cache。
cuda_graph_batch_size – Cuda graph 批次大小,預設為 -1。
dynamic_batch – 批次維度 (dim=0) 為動態。
use_experimental_fx_rt – 使用下一代 TRTModule,其同時支援基於 Python 和 TorchScript 的執行 (包括在 C++ 中)。
- 傳回
由 TensorRT lowered 的 torch.nn.Module。
類別¶
- class torch_tensorrt.fx.TRTModule(engine=None, input_names=None, output_names=None, cuda_graph_batch_size=- 1)[source]¶
- class torch_tensorrt.fx.InputTensorSpec(shape: Sequence[int], dtype: dtype, device: device = device(type='cpu'), shape_ranges: List[Tuple[Sequence[int], Sequence[int], Sequence[int]]] = [], has_batch_dim: bool = True)[source]¶
此類別包含輸入張量的資訊。
shape: 張量的形狀。
dtype: 張量的資料類型。
- device: 張量的裝置。這僅用於產生給定模型的輸入
以便執行形狀傳播。對於 TensorRT 引擎,輸入必須在 CUDA 裝置上。
- shape_ranges: 如果需要動態形狀 (形狀的維度為 -1),則此欄位
必須提供 (預設為空清單)。每個 shape_range 都是三個元組的元組 ((min_input_shape)、(optimized_input_shape)、(max_input_shape))。每個 shape_range 用於填入 TensorRT 優化設定檔。例如,如果輸入形狀從 (1, 224) 變更為 (100, 224),且我們想要針對 (25, 224) 進行優化,因為這是最常見的輸入形狀,則我們將 shape_ranges 設定為 ((1, 224)、(25, 225)、(100, 224))。
- has_batch_dim: 形狀是否包含批次維度。如果引擎想要以動態形狀執行,則必須提供批次維度。
if the engine want to run with dynamic shape.
- class torch_tensorrt.fx.TRTInterpreter(module: GraphModule, input_specs: List[InputTensorSpec], explicit_batch_dimension: bool = False, explicit_precision: bool = False, logger_level=None)[source]¶