• 文件 >
  • Mutable Torch TensorRT 模組
快捷鍵

Mutable Torch TensorRT 模組

我們將示範如何輕鬆使用 Mutable Torch TensorRT 模組來編譯、互動和修改 TensorRT 圖形模組。

編譯 Torch-TensorRT 模組很簡單,但修改編譯後的模組可能具有挑戰性,尤其是在維護 PyTorch 模組和對應的 Torch-TensorRT 模組之間的狀態和連線時。在預先 (AoT) 情境中,將 Torch TensorRT 與複雜的管線(例如 Hugging Face Stable Diffusion 管線)整合變得更加困難。Mutable Torch TensorRT 模組旨在解決這些挑戰,使與 Torch-TensorRT 模組的互動比以往更容易。

在本教學中,我們將逐步介紹 1. 具有 ResNet 18 的 Mutable Torch TensorRT 模組的範例工作流程 2. 儲存 Mutable Torch TensorRT 模組 3. 在 LoRA 用例中與 Huggingface 管線整合

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

使用設定初始化 Mutable Torch TensorRT 模組。

settings = {
    "use_python": False,
    "enabled_precisions": {torch.float32},
    "immutable_weights": False,
}

model = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
mutable_module(*inputs)

對可變模組進行修改。

對可變模組進行變更可能會觸發重新擬合或重新編譯。例如,載入不同的 state_dict 並設定新的權重值將觸發重新擬合,而將模組新增至模型將觸發重新編譯。

model2 = models.resnet18(pretrained=False).eval().to("cuda")
mutable_module.load_state_dict(model2.state_dict())


# Check the output
# The refit happens while you call the mutable module again.
expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
    assert torch.allclose(
        expected_output, refitted_output, 1e-2, 1e-2
    ), "Refit Result is not correct. Refit failed"

print("Refit successfully!")

儲存 Mutable Torch TensorRT 模組

# Currently, saving is only enabled for C++ runtime, not python runtime.
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")

Stable Diffusion 與 Huggingface

# The LoRA checkpoint is from https://civitai.com/models/12597/moxin

from diffusers import DiffusionPipeline

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
        "immutable_weights": False,
    }

    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda:0"

    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"

    pipe = DiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16
    )
    pipe.to(device)

    # The only extra line you need
    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./without_LoRA_mutable.jpg")

    # Standard Huggingface LoRA loading procedure
    pipe.load_lora_weights(
        "stablediffusionapi/load_lora_embeddings",
        weight_name="moxin.safetensors",
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()

    # Refit triggered
    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./with_LoRA_mutable.jpg")

腳本的總執行時間: (0 分鐘 0.000 秒)

由 Sphinx-Gallery 產生的圖庫

文件

存取 PyTorch 的全面開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源