注意
前往結尾以下載完整的程式碼範例
使用 dynamo 後端編譯 SAM2¶
此範例說明使用 Torch-TensorRT 優化的最先進模型 Segment Anything Model 2 (SAM2)。
Segment Anything Model 2 是一個基礎模型,旨在解決圖像和影片中可提示的視覺分割問題。在編譯前,請先安裝以下相依性套件。
pip install -r requirements.txt
需要進行某些客製化修改,以確保模型成功匯出。要套用這些變更,請使用以下的分支安裝 SAM2 (安裝說明)
在客製化的 SAM2 分支中,已套用以下修改以移除圖形斷點並增強延遲效能,確保更有效率的 Torch-TRT 轉換。
一致的資料類型: 保留輸入張量的資料類型,移除強制的 FP32 轉換。
遮罩運算: 使用基於遮罩的索引,而不是直接選擇資料,從而提高 Torch-TRT 的相容性。
安全初始化: 有條件地初始化張量,而不是串連到空的張量。
標準函數: 避免特殊上下文和客製化的 LayerNorm,依靠內建的 PyTorch 函數來提高穩定性。
匯入以下函式庫¶
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_tensorrt
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam_components import SAM2FullModel
matplotlib.use("Agg")
定義 SAM2 模型¶
使用 SAM2ImagePredictor
類別載入 facebook/sam2-hiera-large
預訓練模型。SAM2ImagePredictor
提供了實用工具來預處理圖像、儲存圖像特徵 (透過 set_image
函數) 並預測遮罩 (透過 predict
函數)。
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
為了確保我們成功匯出整個模型 (圖像編碼器和遮罩預測器) 組件,我們建立了一個獨立的模組 SAM2FullModel
,它使用 SAM2ImagePredictor
類別中的這些實用工具。SAM2FullModel
在單一步驟中執行特徵提取和遮罩預測,而不是 SAM2ImagePredictor
的兩步驟流程 (set_image 和 predict 函數)。
class SAM2FullModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.image_encoder = model.forward_image
self._prepare_backbone_features = model._prepare_backbone_features
self.directly_add_no_mem_embed = model.directly_add_no_mem_embed
self.no_mem_embed = model.no_mem_embed
self._features = None
self.prompt_encoder = model.sam_prompt_encoder
self.mask_decoder = model.sam_mask_decoder
self._bb_feat_sizes = [(256, 256), (128, 128), (64, 64)]
def forward(self, image, point_coords, point_labels):
backbone_out = self.image_encoder(image)
_, vision_feats, _, _ = self._prepare_backbone_features(backbone_out)
if self.directly_add_no_mem_embed:
vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
feats = [
feat.permute(1, 2, 0).view(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
high_res_features = [
feat_level[-1].unsqueeze(0) for feat_level in features["high_res_feats"]
]
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=(point_coords, point_labels), boxes=None, masks=None
)
low_res_masks, iou_predictions, _, _ = self.mask_decoder(
image_embeddings=features["image_embed"][-1].unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=True,
repeat_image=point_coords.shape[0] > 1,
high_res_features=high_res_features,
)
out = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
return out
使用預訓練權重初始化 SAM2 模型¶
使用預訓練權重初始化 SAM2FullModel
。由於我們已經初始化了 SAM2ImagePredictor
,我們可以從它 (predictor.model
) 直接使用該模型。我們將模型轉換為 FP16 精度以獲得更快的效能。
encoder = predictor.model.eval().cuda()
sam_model = SAM2FullModel(encoder.half()).eval().cuda()
載入儲存庫中提供的範例圖像。
input_image = Image.open("./truck.jpg").convert("RGB")
載入輸入圖像¶
這是我們將要使用的輸入圖像

input_image = Image.open("./truck.jpg").convert("RGB")
除了輸入圖像之外,我們還提供提示作為輸入,用於預測遮罩。提示可以是框、點以及來自先前預測迭代的遮罩。 在此示範中,我們使用一個點作為提示,類似於 SAM2 儲存庫中的原始筆記本
預處理組件¶
以下函數實現了預處理組件,這些組件對輸入圖像應用轉換並轉換給定的點坐標。 我們使用透過 SAM2ImagePredictor 類別提供的 SAM2Transforms。若要深入了解轉換,請參閱 https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py
def preprocess_inputs(image, predictor):
w, h = image.size
orig_hw = [(h, w)]
input_image = predictor._transforms(np.array(image))[None, ...].to("cuda:0")
point_coords = torch.tensor([[500, 375]], dtype=torch.float).to("cuda:0")
point_labels = torch.tensor([1], dtype=torch.int).to("cuda:0")
point_coords = torch.as_tensor(
point_coords, dtype=torch.float, device=predictor.device
)
unnorm_coords = predictor._transforms.transform_coords(
point_coords, normalize=True, orig_hw=orig_hw[0]
)
labels = torch.as_tensor(point_labels, dtype=torch.int, device=predictor.device)
if len(unnorm_coords.shape) == 2:
unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
input_image = input_image.half()
unnorm_coords = unnorm_coords.half()
return (input_image, unnorm_coords, labels)
後處理組件¶
以下函數實現了後處理組件,其中包括繪製和視覺化遮罩和點。 我們使用 SAM2Transforms 來後處理這些遮罩,並透過信賴度分數對其進行排序。
def postprocess_masks(out, predictor, image):
"""Postprocess low-resolution masks and convert them for visualization."""
orig_hw = (image.size[1], image.size[0]) # (height, width)
masks = predictor._transforms.postprocess_masks(out["low_res_masks"], orig_hw)
masks = (masks > 0.0).squeeze(0).cpu().numpy()
scores = out["iou_predictions"].squeeze(0).cpu().numpy()
sorted_indices = np.argsort(scores)[::-1]
return masks[sorted_indices], scores[sorted_indices]
def show_mask(mask, ax, random_color=False, borders=True):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
if borders:
import cv2
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# Try to smooth contours
contours = [
cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
]
mask_image = cv2.drawContours(
mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
def visualize_masks(
image, masks, scores, point_coords, point_labels, title_prefix="", save=True
):
"""Visualize and save masks overlaid on the original image."""
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_points(point_coords, point_labels, plt.gca())
plt.title(f"{title_prefix} Mask {i + 1}, Score: {score:.3f}", fontsize=18)
plt.axis("off")
plt.savefig(f"{title_prefix}_output_mask_{i + 1}.png")
plt.close()
預處理輸入¶
預處理輸入。在以下程式碼片段中,torchtrt_inputs
包含 (input_image、unnormalized_coordinates 和 labels)。 unnormalized_coordinates 是點的表示,而 label (= 在此演示中為 1) 代表前景點。
torchtrt_inputs = preprocess_inputs(input_image, predictor)
Torch-TensorRT 編譯¶
以非嚴格模式匯出模型,並以 FP16 精度執行 Torch-TensorRT 編譯。我們使用 use_fp32_acc=True
啟用 FP32 matmul 累積,以保持與原始 Pytorch 模型相同的準確度。
exp_program = torch.export.export(sam_model, torchtrt_inputs, strict=False)
trt_model = torch_tensorrt.dynamo.compile(
exp_program,
inputs=torchtrt_inputs,
min_block_size=1,
enabled_precisions={torch.float16},
use_fp32_acc=True,
)
trt_out = trt_model(*torchtrt_inputs)
輸出視覺化¶
後處理 Torch-TensorRT 的輸出,並使用上面提供的後處理組件視覺化遮罩。 輸出應儲存在您目前的目錄中。
trt_masks, trt_scores = postprocess_masks(trt_out, predictor, input_image)
visualize_masks(
input_image,
trt_masks,
trt_scores,
torch.tensor([[500, 375]]),
torch.tensor([1]),
title_prefix="Torch-TRT",
)
參考資料¶
腳本的總執行時間:(0 分鐘 0.000 秒)