• 文件 >
  • 使用 dynamo 後端編譯 SAM2
捷徑

使用 dynamo 後端編譯 SAM2

此範例說明使用 Torch-TensorRT 優化的最先進模型 Segment Anything Model 2 (SAM2)

Segment Anything Model 2 是一個基礎模型,旨在解決圖像和影片中可提示的視覺分割問題。在編譯前,請先安裝以下相依性套件。

pip install -r requirements.txt

需要進行某些客製化修改,以確保模型成功匯出。要套用這些變更,請使用以下的分支安裝 SAM2 (安裝說明)

在客製化的 SAM2 分支中,已套用以下修改以移除圖形斷點並增強延遲效能,確保更有效率的 Torch-TRT 轉換。

  • 一致的資料類型: 保留輸入張量的資料類型,移除強制的 FP32 轉換。

  • 遮罩運算: 使用基於遮罩的索引,而不是直接選擇資料,從而提高 Torch-TRT 的相容性。

  • 安全初始化: 有條件地初始化張量,而不是串連到空的張量。

  • 標準函數: 避免特殊上下文和客製化的 LayerNorm,依靠內建的 PyTorch 函數來提高穩定性。

匯入以下函式庫

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_tensorrt
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam_components import SAM2FullModel

matplotlib.use("Agg")

定義 SAM2 模型

使用 SAM2ImagePredictor 類別載入 facebook/sam2-hiera-large 預訓練模型。SAM2ImagePredictor 提供了實用工具來預處理圖像、儲存圖像特徵 (透過 set_image 函數) 並預測遮罩 (透過 predict 函數)。

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

為了確保我們成功匯出整個模型 (圖像編碼器和遮罩預測器) 組件,我們建立了一個獨立的模組 SAM2FullModel,它使用 SAM2ImagePredictor 類別中的這些實用工具。SAM2FullModel 在單一步驟中執行特徵提取和遮罩預測,而不是 SAM2ImagePredictor 的兩步驟流程 (set_image 和 predict 函數)。

class SAM2FullModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.image_encoder = model.forward_image
        self._prepare_backbone_features = model._prepare_backbone_features
        self.directly_add_no_mem_embed = model.directly_add_no_mem_embed
        self.no_mem_embed = model.no_mem_embed
        self._features = None

        self.prompt_encoder = model.sam_prompt_encoder
        self.mask_decoder = model.sam_mask_decoder

        self._bb_feat_sizes = [(256, 256), (128, 128), (64, 64)]

    def forward(self, image, point_coords, point_labels):
        backbone_out = self.image_encoder(image)
        _, vision_feats, _, _ = self._prepare_backbone_features(backbone_out)

        if self.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}

        high_res_features = [
            feat_level[-1].unsqueeze(0) for feat_level in features["high_res_feats"]
        ]

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=(point_coords, point_labels), boxes=None, masks=None
        )

        low_res_masks, iou_predictions, _, _ = self.mask_decoder(
            image_embeddings=features["image_embed"][-1].unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=point_coords.shape[0] > 1,
            high_res_features=high_res_features,
        )

        out = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
        return out

使用預訓練權重初始化 SAM2 模型

使用預訓練權重初始化 SAM2FullModel。由於我們已經初始化了 SAM2ImagePredictor,我們可以從它 (predictor.model) 直接使用該模型。我們將模型轉換為 FP16 精度以獲得更快的效能。

encoder = predictor.model.eval().cuda()
sam_model = SAM2FullModel(encoder.half()).eval().cuda()

載入儲存庫中提供的範例圖像。

input_image = Image.open("./truck.jpg").convert("RGB")

載入輸入圖像

這是我們將要使用的輸入圖像

../../../_images/truck.jpg
input_image = Image.open("./truck.jpg").convert("RGB")

除了輸入圖像之外,我們還提供提示作為輸入,用於預測遮罩。提示可以是框、點以及來自先前預測迭代的遮罩。 在此示範中,我們使用一個點作為提示,類似於 SAM2 儲存庫中的原始筆記本

預處理組件

以下函數實現了預處理組件,這些組件對輸入圖像應用轉換並轉換給定的點坐標。 我們使用透過 SAM2ImagePredictor 類別提供的 SAM2Transforms。若要深入了解轉換,請參閱 https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py

def preprocess_inputs(image, predictor):
    w, h = image.size
    orig_hw = [(h, w)]
    input_image = predictor._transforms(np.array(image))[None, ...].to("cuda:0")

    point_coords = torch.tensor([[500, 375]], dtype=torch.float).to("cuda:0")
    point_labels = torch.tensor([1], dtype=torch.int).to("cuda:0")

    point_coords = torch.as_tensor(
        point_coords, dtype=torch.float, device=predictor.device
    )
    unnorm_coords = predictor._transforms.transform_coords(
        point_coords, normalize=True, orig_hw=orig_hw[0]
    )
    labels = torch.as_tensor(point_labels, dtype=torch.int, device=predictor.device)
    if len(unnorm_coords.shape) == 2:
        unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]

    input_image = input_image.half()
    unnorm_coords = unnorm_coords.half()

    return (input_image, unnorm_coords, labels)

後處理組件

以下函數實現了後處理組件,其中包括繪製和視覺化遮罩和點。 我們使用 SAM2Transforms 來後處理這些遮罩,並透過信賴度分數對其進行排序。

def postprocess_masks(out, predictor, image):
    """Postprocess low-resolution masks and convert them for visualization."""
    orig_hw = (image.size[1], image.size[0])  # (height, width)
    masks = predictor._transforms.postprocess_masks(out["low_res_masks"], orig_hw)
    masks = (masks > 0.0).squeeze(0).cpu().numpy()
    scores = out["iou_predictions"].squeeze(0).cpu().numpy()
    sorted_indices = np.argsort(scores)[::-1]
    return masks[sorted_indices], scores[sorted_indices]


def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [
            cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
        ]
        mask_image = cv2.drawContours(
            mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
        )
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def visualize_masks(
    image, masks, scores, point_coords, point_labels, title_prefix="", save=True
):
    """Visualize and save masks overlaid on the original image."""
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(point_coords, point_labels, plt.gca())
        plt.title(f"{title_prefix} Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis("off")
        plt.savefig(f"{title_prefix}_output_mask_{i + 1}.png")
        plt.close()

預處理輸入

預處理輸入。在以下程式碼片段中,torchtrt_inputs 包含 (input_image、unnormalized_coordinates 和 labels)。 unnormalized_coordinates 是點的表示,而 label (= 在此演示中為 1) 代表前景點。

torchtrt_inputs = preprocess_inputs(input_image, predictor)

Torch-TensorRT 編譯

以非嚴格模式匯出模型,並以 FP16 精度執行 Torch-TensorRT 編譯。我們使用 use_fp32_acc=True 啟用 FP32 matmul 累積,以保持與原始 Pytorch 模型相同的準確度。

exp_program = torch.export.export(sam_model, torchtrt_inputs, strict=False)
trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=torchtrt_inputs,
    min_block_size=1,
    enabled_precisions={torch.float16},
    use_fp32_acc=True,
)
trt_out = trt_model(*torchtrt_inputs)

輸出視覺化

後處理 Torch-TensorRT 的輸出,並使用上面提供的後處理組件視覺化遮罩。 輸出應儲存在您目前的目錄中。

trt_masks, trt_scores = postprocess_masks(trt_out, predictor, input_image)
visualize_masks(
    input_image,
    trt_masks,
    trt_scores,
    torch.tensor([[500, 375]]),
    torch.tensor([1]),
    title_prefix="Torch-TRT",
)
預測的遮罩如下所示
../../../_images/sam_mask1.png ../../../_images/sam_mask2.png ../../../_images/sam_mask3.png

參考資料

腳本的總執行時間:(0 分鐘 0.000 秒)

由 Sphinx-Gallery 產生的圖庫

文件

存取 PyTorch 的全面開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源