• 文件 >
  • TorchDynamo 整合至 PyTorch XLA
捷徑

TorchDynamo 整合至 PyTorch XLA

TorchDynamo 是一個 Python 層級的 JIT 編譯器,旨在使未修改的 PyTorch 程式執行得更快。它為編譯器後端提供了一個清晰的 API 以進行掛鉤,其最大的特色是在 Python 位元組碼執行之前動態修改它。在 pytorch/xla 2.0 版本中,PyTorch/XLA 為 TorchDynamo 提供了實驗性的後端,用於推論和訓練。

XLA bridge 的運作方式是,當 Dynamo 識別出模型模式時,它將提供一個 TorchFX 圖,而 PyTorch/XLA 將使用現有的 Lazy Tensor 技術來編譯 FX 圖並傳回已編譯的函式。

整合

目前透過將 backend='openxla' 引數新增至 torch.compile,即可支援 PyTorch/XLA 和 Dynamo。例如:

import torch
import torch_xla.core.xla_model as xm

def add(a, b):
  a_xla = a.to(xm.xla_device())
  b_xla = b.to(xm.xla_device())
  return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))

推論

以下是以 torch.compile 執行 resnet18 的小程式碼範例:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  dynamo_resnet18 = torch.compile(
    xla_resnet18, backend='openxla')
  for data, _ in loader:
    with torch.no_grad():
      output = dynamo_resnet18(data)

使用 torch.compile,您將看到 PyTorch/XLA 只會在初始化時追蹤 resent18 模型一次,並在每次調用 dynamo_resnet18 時執行已編譯的二進位檔,而不是每次都追蹤模型。以下是在 Cloud TPU v4-8 上使用 torch bench 比較 Dynamo 和 Lazy 的推論速度分析:

模型 加速
resnet18 2.59
resnet50 2.64
resnext50_32x4d 1.91
alexnet 1.28
mobilenet_v2 18.62
mnasnet1_0 2.68
vgg16 1.33
BERT_pytorch 7.49
squeezenet1_1 2.29
timm_vision_transformer 3.52
幾何平均數 3.04

訓練

PyTorch/XLA 也支援 Dynamo 進行訓練,但它仍處於實驗階段,我們正與 PyTorch Compiler 團隊合作迭代實作。以下是以 torch.compile 訓練 resnet18 的範例:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target, optimizer):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  optimizer.step()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  dynamo_train_model = torch.compile(
        train_model, backend='openxla')
  for data, target in loader:
    xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
    output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)

如果您使用 Lazy tensor,我們預期每個訓練步驟會提取和執行 3 個圖,而不是每個訓練步驟 1 個圖。以下是在 Cloud TPU v4-8 上使用 torch bench 比較 Dynamo 和 Lazy 的訓練速度分析。

模型 加速
resnet50 1.33
resnet18 1.33
BERT_pytorch 3.07
resnext50_32x4d 1.43
alexnet 1.12
mobilenet_v2 1.4
mnasnet1_0 1.19
vgg16 0.81
timm_vision_transformer 1.87
squeezenet1_1 1.41
幾何平均數 1.41

注意: 我們針對每個模型的 fwd 和 bwd 執行單一步驟,然後收集端對端時間。在現實世界中,我們將在每個訓練工作中執行多個步驟,這可以輕鬆地隱藏來自執行的追蹤成本 (因為它是非同步的)。在這種情況下,Lazy Tensor 將具有更好的效能。

功能差距

我們想指出一個差距,這個差距阻礙了我們在更大規模的模型上使用 TorchDynamo。

TorchDynamo 將向前和向後追蹤到不同的圖中。對於 PyTorch/XLA 來說,讓 XLA 編譯器將整個步驟視為一個圖,以最佳化速度非常重要。啟動每個裝置執行也存在固定的額外負擔,這使得每個訓練步驟執行多個圖變得不太理想。

與 Lazy Tensor 相比,這個差距使得它在真實世界的訓練用例中效率較低,尤其是在訓練中,追蹤成本可以與執行重疊。

重點

TorchDynamo 為編譯器後端提供了一種非常有前景的方式,可以向使用者隱藏複雜性,並輕鬆地以圖形格式檢索建模程式碼。與 PyTorch/XLA 傳統的 Lazy Tensor 圖形提取方式相比,TorchDynamo 可以跳過每個迭代的圖形追蹤,因此提供更好的推論回應時間。

大多數 PyTorch/XLA 支援的模型,在使用新的 dynamo-xla bridge 執行推論時,都看到了顯著的加速。我們的社群正努力擴展支援模型的集合。關於上述訓練功能差距,PyTorch/XLA 社群非常興奮能在我們即將進行的開發工作中改善訓練差距。團隊將繼續大力投資 TorchDynamo,並與上游合作,使訓練故事更加成熟。

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源