• 文件 >
  • Eager 模式 + 編譯 API
快捷方式

Eager 模式 + 編譯 API

在這份文件中,我們將介紹如何使用 PyTorch/XLA 新的實驗性 eager 模式與 compile API。目標是讓 PyTorch/XLA 的體驗更貼近原生 PyTorch,並簡化開發流程。

目前 PyTorch/XLA 預設以 LazyTensor 追蹤模式執行。在以下程式碼中

import torch
import torch_xla
import torchvision

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# model tracing
res = model(input)

# model execution, same as `xm.mark_step`
torch_xla.sync()

實際的模型編譯和裝置執行發生在呼叫 torch_xla.sync 時。這種方法有多個缺點。

  1. 使用者常常對框架何時在追蹤以及何時在執行感到困惑。

  2. 非核心模型程式碼 (例如資料預處理) 常常產生一些小的待處理執行,這些執行會洩漏到主圖 (step 函數) 中並導致重新編譯。整個圖的重新編譯通常非常耗費資源。

  3. 很難除錯重新編譯何時/為何發生。

為了減輕上述問題,我們希望引入新的 UX,結合 eager 和 compile。

基本用法

import torch
import torch_xla
import torchvision

# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

# Mark the function to be compiled
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)

# Compilation and execution happens right away.
res = compiled_model(input)

請注意

  1. 目前使用者必須手動透過 torch_xla.experimental.eager_mode(True) 啟用 eager 模式。

  2. 想要編譯的程式碼區域應該用 torch_xla.compile 包裹起來。

torch_xla.compile 的實作實際上非常直接,它在進入目標函數時停用 eager 模式並開始追蹤。它會在目標函數返回時呼叫 torch_xla.sync() 並重新啟用 eager 模式。相較於現有的 mark_step/sync 方法,使用 eager + compile API 可以預期相同的效能。

推論

torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")

建議在推論時使用 torch.compile 而不是 torch_xla.compile,以減少追蹤的 overhead。

訓練

torch_xla.experimental.eager_mode(True)

def step_fn(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)

在訓練中,我們要求使用者將 step_fn 重構出來,因為通常最好將模型的前向、反向和優化器一起編譯。長遠目標是也在訓練中使用 torch.compile,但目前我們建議使用者使用 torch_xla.compile (基於效能考量)。

基準測試

我使用虛擬資料在 v4-8 單晶片上執行了一個 2 層解碼器模型訓練 (基本上就是 llama2) 300 步。以下是我觀察到的數字。

模式 token/秒


追蹤模式 (基準線) 147 Eager 模式 65 Eager + torch_xla compile 147

: Eager 模式基準測試

對於僅解碼器模型,Eager 模式可以達到完全編譯模型約 45% 的效能。如需更多資訊,請參閱 train_decoder_only_base.pyeager example。請注意,eager 模式的效能非常依賴模型。當我嘗試執行 resnet50 時,eager 模式的效能約為編譯模式的 1%。我們不期望使用者使用 eager 模式來執行主要訓練迴圈。Eager 模式旨在用於處理訓練/推論邏輯的非核心部分 (資料預處理、隨機數生成等) 或除錯。

文件

存取全面的 PyTorch 開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源