捷徑

torch.export 流程、常見挑戰以及解決方案的示範

作者: Ankith GunapalJordi RamonMarcos Carranza

torch.export 教學簡介 中,我們學習了如何使用 torch.export。本教學是前一個教學的擴展,探討了使用程式碼匯出熱門模型的過程,並解決了使用 torch.export 可能會出現的常見挑戰。

在本教學中,您將學習如何針對以下使用案例匯出模型

選擇這四個模型中的每一個,都是為了示範 torch.export 的獨特功能,以及實作中面臨的一些實際考量和問題。

先決條件

  • PyTorch 2.4 或更新版本

  • torch.export 和 PyTorch Eager 推論有基本了解。

torch.export 的關鍵需求:沒有圖形中斷

torch.compile 透過使用 JIT 將 PyTorch 程式碼編譯為最佳化的核心,來加速 PyTorch 程式碼。它使用 TorchDynamo 最佳化給定的模型,並建立最佳化的圖形,然後使用 API 中指定的後端將其降低到硬體中。當 TorchDynamo 遇到不支援的 Python 功能時,它會中斷計算圖形,讓預設的 Python 解譯器處理不支援的程式碼,然後繼續捕獲圖形。計算圖形中的這種中斷稱為 圖形中斷

torch.exporttorch.compile 之間的一個主要差異是,torch.export 不支援圖形中斷,這表示您要匯出的整個模型或部分模型需要是單一圖形。這是因為處理圖形中斷涉及使用預設的 Python 評估來解譯不支援的操作,這與 torch.export 的設計目的不符。您可以在此連結中閱讀有關各種 PyTorch 框架之間差異的詳細資訊

您可以使用以下命令來識別程式中的圖形中斷

TORCH_LOGS="graph_breaks" python <file_name>.py

您需要修改程式以消除圖形中斷。解決後,您就可以匯出模型了。PyTorch 為熱門的 HuggingFace 和 TIMM 模型上的 torch.compile 執行夜間基準測試。這些模型中的大多數都沒有圖形中斷。

此配方中的模型沒有圖形中斷,但在使用 torch.export 時會失敗。

影片分類

MViT 是一類基於 MultiScale Vision Transformers 的模型。此模型已使用 Kinetics-400 數據集進行影片分類訓練。具有相關數據集的此模型可用於遊戲環境中的動作識別。

以下程式碼透過追蹤 batch_size=2 來匯出 MViT,然後檢查 ExportedProgram 是否能以 batch_size=4 執行。

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb

model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)
# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
exported_program = torch.export.export(
    model,
    (input_frames,),
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

錯誤:靜態批次大小

    raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4

預設情況下,匯出流程會追蹤程式,假設所有輸入形狀都是靜態的。因此,如果您以與追蹤時不同的輸入形狀執行程式,就會遇到錯誤。

解決方案

為了解決此錯誤,我們指定輸入的第一個維度(batch_size)為動態,並指定 batch_size 的預期範圍。在下面更正後的範例中,我們指定預期的 batch_size 可以從 1 到 16。需要注意的是,min=2 並非錯誤,其原因已在 0/1 特殊化問題 中說明。關於 torch.export 的動態形狀的詳細說明,請參閱匯出教學課程。下面的程式碼示範了如何匯出具有動態批次大小的 mViT。

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb


model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)

# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
batch_dim = torch.export.Dim("batch", min=2, max=16)
exported_program = torch.export.export(
    model,
    (input_frames,),
    # Specify the first dimension of the input x as dynamic
    dynamic_shapes={"x": {0: batch_dim}},
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

自動語音辨識

自動語音辨識 (ASR) 是使用機器學習將口語轉錄成文字。 Whisper 是 OpenAI 基於 Transformer 的編碼器-解碼器模型,該模型在 68 萬小時的標記資料上進行了 ASR 和語音翻譯的訓練。下面的程式碼嘗試匯出用於 ASR 的 whisper-tiny 模型。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,))

錯誤:使用 TorchDynamo 進行嚴格追蹤

torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'

預設情況下,torch.export 使用 TorchDynamo(一種位元組碼分析引擎)追蹤您的程式碼,該引擎以符號方式分析您的程式碼並建立圖表。此分析提供了更強的安全保證,但並非所有 Python 程式碼都支援。當我們使用預設的嚴格模式匯出 whisper-tiny 模型時,由於不支援的功能,它通常會在 Dynamo 中傳回錯誤。要了解為什麼這會在 Dynamo 中發生錯誤,您可以參考此 GitHub 問題

解決方案

為了解決上述錯誤,torch.export 支援 non_strict 模式,其中程式使用 Python 直譯器進行追蹤,這與 PyTorch 即時執行類似。唯一的區別是所有 Tensor 物件將被 ProxyTensors 取代,後者會將所有操作記錄到圖表中。透過使用 strict=False,我們可以匯出程式。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,), strict=False)

影像標題生成

影像標題生成 是用文字定義影像內容的任務。在遊戲的背景下,影像標題生成可用於透過動態生成場景中各種遊戲物件的文字描述來增強遊戲體驗,從而為遊戲玩家提供更多詳細資訊。 BLIP 是用於影像標題生成的流行模型,由 SalesForce Research 發布。下面的程式碼嘗試以 batch_size=1 匯出 BLIP。

import torch
from models.blip import blip_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384
image = torch.randn(1, 3,384,384).to(device)
caption_input = ""

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(image,caption_input,), strict=False)

錯誤:無法修改具有凍結儲存體的張量

匯出模型時,可能會失敗,因為模型實作可能包含某些 torch.export 尚未支援的 Python 操作。其中一些失敗可能有解決方法。 BLIP 是一個範例,其中原始模型會發生錯誤,可以透過對程式碼進行少量變更來解決。 torch.exportExportDB 中列出了支援和不支援操作的常見案例,並展示了如何修改您的程式碼以使其與匯出相容。

File "/BLIP/models/blip.py", line 112, in forward
    text.input_ids[:,0] = self.tokenizer.bos_token_id
  File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
RuntimeError: cannot mutate tensors with frozen storage

解決方案

複製 張量,其中匯出失敗。

text.input_ids = text.input_ids.clone() # clone the tensor
text.input_ids[:,0] = self.tokenizer.bos_token_id

注意

此限制已在 PyTorch 2.7 nightly 版本中放寬。這應該可以在 PyTorch 2.7 中直接使用。

可提示影像分割

影像分割 是一種電腦視覺技術,它根據像素的特徵將數位影像劃分為不同的像素組或區段。 Segment Anything Model (SAM) 引入了可提示影像分割,該分割根據指示所需物件的提示來預測物件遮罩。 SAM 2 是第一個用於分割影像和影片中物件的統一模型。 SAM2ImagePredictor 類別提供了一個簡單的介面來提示模型。該模型可以將點和框提示以及來自先前預測迭代的遮罩作為輸入。由於 SAM2 為物件追蹤提供強大的零樣本效能,因此可用於追蹤場景中的遊戲物件。

SAM2ImagePredictor 的 predict 方法中的張量運算發生在 _predict 方法中。所以,我們嘗試這樣匯出。

ep = torch.export.export(
    self._predict,
    args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
    kwargs={"return_logits": return_logits},
    strict=False,
)

錯誤:模型不是 torch.nn.Module 類型

torch.export 期望模組的類型為 torch.nn.Module。但是,我們嘗試匯出的模組是一個類別方法。因此它會發生錯誤。

Traceback (most recent call last):
  File "/sam2/image_predict.py", line 20, in <module>
    masks, scores, _ = predictor.predict(
  File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
    ep = torch.export.export(
  File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
    raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.

解決方案

我們編寫一個輔助類別,該類別繼承自 torch.nn.Module,並在該類別的 forward 方法中呼叫 _predict method。完整的程式碼可以在這裡找到。

class ExportHelper(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(_, *args, **kwargs):
        return self._predict(*args, **kwargs)

 model_to_export = ExportHelper()
 ep = torch.export.export(
      model_to_export,
      args=(unnorm_coords, labels, unnorm_box, mask_input,  multimask_output),
      kwargs={"return_logits": return_logits},
      strict=False,
      )

結論

在本教學課程中,我們學習了如何透過正確的配置和簡單的程式碼修改來解決挑戰,從而使用 torch.export 匯出熱門用例的模型。匯出模型後,如果您使用的是伺服器,可以使用 AOTInductorExportedProgram 降低到您的硬體中,如果使用的是邊緣裝置,則可以使用 ExecuTorch。若要了解有關 AOTInductor (AOTI) 的更多資訊,請參閱 AOTI 教學課程。若要了解有關 ExecuTorch 的更多資訊,請參閱 ExecuTorch 教學課程

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得適合初學者和進階開發人員的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源