捷徑

了解基於 TorchDynamo 的 ONNX 匯出器記憶體用量

先前的基於 TorchScript 的 ONNX 匯出器會執行模型一次以追蹤其執行,如果模型的記憶體需求超過可用的 GPU 記憶體,可能會導致 GPU 記憶體不足。新的基於 TorchDynamo 的 ONNX 匯出器已解決此問題。

基於 TorchDynamo 的 ONNX 匯出器利用 FakeTensorMode,以避免在匯出過程中執行實際的張量計算。與基於 TorchScript 的 ONNX 匯出器相比,此方法可顯著降低記憶體用量。

以下範例展示了基於 TorchScript 和基於 TorchDynamo 的 ONNX 匯出器之間的記憶體用量差異。在此範例中,我們使用來自 MONAI 的 HighResNet 模型。在繼續之前,請從 PyPI 安裝它

pip install monai

PyTorch 提供一個工具,用於擷取和視覺化記憶體用量追蹤。我們將使用此工具記錄兩個匯出器在匯出過程中的記憶體用量,並比較結果。您可以在了解 CUDA 記憶體用量上找到有關此工具的更多詳細資訊。

基於 TorchScript 的匯出器

可以執行以下程式碼來產生快照檔案,該檔案記錄匯出過程中已配置 CUDA 記憶體的狀態。

import torch

from torch.onnx.utils import export
from monai.networks.nets import (
    HighResNet,
)

torch.cuda.memory._record_memory_history()

model = HighResNet(
    spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()

model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

with torch.no_grad():
    export(
        model,
        data,
        "torchscript_exporter_highresnet.onnx",
    )

snapshot_name = f"torchscript_exporter_example.pickle"
print(f"generate {snapshot_name}")

torch.cuda.memory._dump_snapshot(snapshot_name)
print(f"Export is done.")

開啟 pytorch.org/memory_viz 並將產生的 pickled 快照檔案拖放到可視化工具中。記憶體用量如下所述

_images/torch_script_exporter_memory_usage.png

透過此圖,我們可以看到記憶體用量峰值超過 2.8GB。

基於 TorchDynamo 的匯出器

可以執行以下程式碼來產生快照檔案,該檔案記錄匯出過程中已配置 CUDA 記憶體的狀態。

import torch

from monai.networks.nets import (
    HighResNet,
)

torch.cuda.memory._record_memory_history()

model = HighResNet(
    spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()

model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

with torch.no_grad():
    onnx_program = torch.onnx.export(
                        model,
                        data,
                        "test_faketensor.onnx",
                        dynamo=True,
                    )

snapshot_name = f"torchdynamo_exporter_example.pickle"
print(f"generate {snapshot_name}")

torch.cuda.memory._dump_snapshot(snapshot_name)
print(f"Export is done.")

開啟 pytorch.org/memory_viz 並將產生的 pickled 快照檔案拖放到可視化工具中。記憶體用量如下所述

_images/torch_dynamo_exporter_memory_usage.png

透過此圖,我們可以看到記憶體用量峰值僅約為 45MB。與基於 TorchScript 的匯出器的記憶體用量峰值相比,它減少了 98% 的記憶體用量。

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

為初學者和進階開發者提供深入教學

查看教學

資源

尋找開發資源並獲得解答

查看資源