TorchDynamo 整合至 PyTorch XLA¶
TorchDynamo 是一個 Python 層級的 JIT 編譯器,旨在使未修改的 PyTorch 程式執行得更快。它為編譯器後端提供了一個清晰的 API 以進行掛鉤,其最大的特色是在 Python 位元組碼執行之前動態修改它。在 pytorch/xla 2.0 版本中,PyTorch/XLA 為 TorchDynamo 提供了實驗性的後端,用於推論和訓練。
XLA bridge 的運作方式是,當 Dynamo 識別出模型模式時,它將提供一個 TorchFX 圖,而 PyTorch/XLA 將使用現有的 Lazy Tensor 技術來編譯 FX 圖並傳回已編譯的函式。
整合¶
目前透過將 backend='openxla'
引數新增至 torch.compile
,即可支援 PyTorch/XLA 和 Dynamo。例如:
import torch
import torch_xla.core.xla_model as xm
def add(a, b):
a_xla = a.to(xm.xla_device())
b_xla = b.to(xm.xla_device())
return a_xla + b_xla
compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))
推論¶
以下是以 torch.compile
執行 resnet18 的小程式碼範例:
import torch
import torchvision
import torch_xla.core.xla_model as xm
def eval_model(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(
xla_resnet18, backend='openxla')
for data, _ in loader:
with torch.no_grad():
output = dynamo_resnet18(data)
使用 torch.compile
,您將看到 PyTorch/XLA 只會在初始化時追蹤 resent18 模型一次,並在每次調用 dynamo_resnet18
時執行已編譯的二進位檔,而不是每次都追蹤模型。以下是在 Cloud TPU v4-8 上使用 torch bench 比較 Dynamo 和 Lazy 的推論速度分析:
模型 | 加速 |
---|---|
resnet18 | 2.59 |
resnet50 | 2.64 |
resnext50_32x4d | 1.91 |
alexnet | 1.28 |
mobilenet_v2 | 18.62 |
mnasnet1_0 | 2.68 |
vgg16 | 1.33 |
BERT_pytorch | 7.49 |
squeezenet1_1 | 2.29 |
timm_vision_transformer | 3.52 |
幾何平均數 | 3.04 |
訓練¶
PyTorch/XLA 也支援 Dynamo 進行訓練,但它仍處於實驗階段,我們正與 PyTorch Compiler 團隊合作迭代實作。以下是以 torch.compile
訓練 resnet18 的範例:
import torch
import torchvision
import torch_xla.core.xla_model as xm
def train_model(model, data, target, optimizer):
loss_fn = torch.nn.CrossEntropyLoss()
pred = model(data)
loss = loss_fn(pred, target)
loss.backward()
optimizer.step()
return pred
def train_model_main(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.train()
dynamo_train_model = torch.compile(
train_model, backend='openxla')
for data, target in loader:
xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)
如果您使用 Lazy tensor,我們預期每個訓練步驟會提取和執行 3 個圖,而不是每個訓練步驟 1 個圖。以下是在 Cloud TPU v4-8 上使用 torch bench 比較 Dynamo 和 Lazy 的訓練速度分析。
模型 | 加速 |
---|---|
resnet50 | 1.33 |
resnet18 | 1.33 |
BERT_pytorch | 3.07 |
resnext50_32x4d | 1.43 |
alexnet | 1.12 |
mobilenet_v2 | 1.4 |
mnasnet1_0 | 1.19 |
vgg16 | 0.81 |
timm_vision_transformer | 1.87 |
squeezenet1_1 | 1.41 |
幾何平均數 | 1.41 |
注意: 我們針對每個模型的 fwd 和 bwd 執行單一步驟,然後收集端對端時間。在現實世界中,我們將在每個訓練工作中執行多個步驟,這可以輕鬆地隱藏來自執行的追蹤成本 (因為它是非同步的)。在這種情況下,Lazy Tensor 將具有更好的效能。
功能差距¶
我們想指出一個差距,這個差距阻礙了我們在更大規模的模型上使用 TorchDynamo。
TorchDynamo 將向前和向後追蹤到不同的圖中。對於 PyTorch/XLA 來說,讓 XLA 編譯器將整個步驟視為一個圖,以最佳化速度非常重要。啟動每個裝置執行也存在固定的額外負擔,這使得每個訓練步驟執行多個圖變得不太理想。
與 Lazy Tensor 相比,這個差距使得它在真實世界的訓練用例中效率較低,尤其是在訓練中,追蹤成本可以與執行重疊。
重點¶
TorchDynamo 為編譯器後端提供了一種非常有前景的方式,可以向使用者隱藏複雜性,並輕鬆地以圖形格式檢索建模程式碼。與 PyTorch/XLA 傳統的 Lazy Tensor 圖形提取方式相比,TorchDynamo 可以跳過每個迭代的圖形追蹤,因此提供更好的推論回應時間。
大多數 PyTorch/XLA 支援的模型,在使用新的 dynamo-xla bridge 執行推論時,都看到了顯著的加速。我們的社群正努力擴展支援模型的集合。關於上述訓練功能差距,PyTorch/XLA 社群非常興奮能在我們即將進行的開發工作中改善訓練差距。團隊將繼續大力投資 TorchDynamo,並與上游合作,使訓練故事更加成熟。