快捷方式

(beta) 使用 torch.compile 編譯優化器

建立於:2024 年 1 月 24 日 | 最後更新:2024 年 1 月 29 日 | 最後驗證:2024 年 11 月 05 日

作者: Michael Lazos

優化器是訓練任何深度學習模型的關鍵演算法。由於它負責更新每個模型參數,因此對於大型模型而言,它通常會成為訓練效能的瓶頸。在本食譜中,我們將應用 torch.compile 到優化器,以觀察 GPU 效能的提升。

注意

本教學需要 PyTorch 2.2.0 或更新版本。

模型設定

在此範例中,我們將使用一系列簡單的線性層。由於我們僅對優化器進行基準測試,因此模型的選擇並不重要,因為優化器效能是參數數量的函數。

根據您使用的機器,您的確切結果可能會有所不同。

import torch

model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
output = model(input)
output.sum().backward()

設定和執行優化器基準測試

在此範例中,我們將使用 Adam 優化器,並建立一個輔助函數來將 step() 包裹在 torch.compile() 中。

注意

torch.compile 僅支援計算能力 >= 7.0 的 cuda 裝置

# exit cleanly if we are on a device that doesn't support torch.compile
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)


opt = torch.optim.Adam(model.parameters(), lr=0.01)


@torch.compile(fullgraph=False)
def fn():
    opt.step()


# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


# Warmup runs to compile the function
for _ in range(5):
    fn()

eager_runtime = benchmark_torch_function_in_microseconds(opt.step)
compiled_runtime = benchmark_torch_function_in_microseconds(fn)

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime}us")
print(f"compiled runtime: {compiled_runtime}us")

範例結果

  • Eager 執行時間:747.2437149845064us

  • 編譯執行時間:392.07384741178us

另請參閱

  • 如需深入的技術概述,請參閱

使用 PT2 編譯優化器

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源