捷徑

PJRT 執行階段

PyTorch/XLA 已從基於 TensorFlow 的 XRT 執行階段遷移至 PJRT 執行階段JAX 也使用此執行階段。

如果您遇到 PJRT 的錯誤,請在 GitHub 上提交問題,並加上 runtime 標籤。

PyTorch/XLA r2.1 中的新功能:

  • PJRT 在 PyTorch/XLA r2.1 中已穩定!

  • 公用執行階段 API 已從 torch_xla.experimental.pjrt 移至 torch_xla.runtime

    • pjrt:// 初始化方法已重新命名為 xla://,並由 torch_xla.distributed.xla_backend 註冊。

    • 先前的 torch_xla.experimental.* 名稱在此版本中仍然可用,以實現相容性。

  • 使用 init_method='xla://' 時,現在支援 torchrun

  • 透過 PJRT C API 針對 XPU 和 Neuron 的新外掛程式。

PyTorch/XLA r2.0 中的新功能:

  • 如果您未傳入任何其他執行階段組態,PJRT 將預設設定。如果您繼續設定 XRT 組態 (XRT_TPU_CONFIG),此變更不會產生任何影響

  • libtpu 中的新 TPU 執行階段實作將效能提升高達 30%。

  • 新的 xm.rendezvous 實作可擴展至數千個 TPU 核心

  • [實驗性] torch.distributed 支援 TPU v2 和 v3,包括 pjrt:// init_method

TL;DR

  • 若要使用 PJRT 預覽執行階段,請將 PJRT_DEVICE 環境變數設定為 CPUTPUCUDA

  • 在 XRT 中,所有分散式工作負載都是多進程的,每個裝置一個進程。在 PJRT 中的 TPU v2 和 v3 上,工作負載是多進程和多執行緒的 (4 個進程,每個進程 2 個執行緒),因此您的工作負載應為執行緒安全。請參閱 TPU v2/v3 上的多執行緒API 指南的多進程章節,以取得更多資訊。需要記住的主要差異

    • 若要以執行緒安全的方式初始化模型,請在初始化後跨複本廣播參數 (torch_xla.experimental.pjrt.broadcast_master_param),或從通用檢查點載入每個複本的參數。

    • 對於其他隨機數字產生,請盡可能使用 torch.Generator。即使您在各個複本中設定相同的 torch.manual_seed,全域 torch RNG 也不是執行緒安全的。

    • 若要使用 torch.distributed,請匯入 torch_xla.experimental.pjrt_backend 並使用 xla:// init_method

    • 這些步驟對於 GPU 和 TPU v4 是選用的。

從 XRT 到 PJRT 的範例差異

import os

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
+import torch_xla.runtime as xr


def _mp_fn(index):
  device = xm.xla_device()
-  dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+  dist.init_process_group('xla', init_method='xla://')

  torch.manual_seed(42)
  model = nn.Linear(128, 10).to(device)

+  # Optional for TPU v4 and GPU
+  xm.broadcast_master_param(model)
  model = DDP(model, gradient_as_bucket_view=True)

  loss_fn = nn.MSELoss()
  optimizer = optim.SGD(model.parameters(), lr=.001)

  for i in range(10):
    data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()

    optimizer.step()
    xm.mark_step()

  # Print mean parameters so we can confirm they're the same across replicas
  print([p.mean() for p in model.parameters()])

if __name__ == '__main__':
-  os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'

+  # Recommended: set PJRT_DEVICE to your local device type
+  os.environ['PJRT_DEVICE'] = 'TPU'

  torch_xla.launch(_mp_fn)

優點

  • 簡單的執行階段組態:只需將 PJRT_DEVICE 設定為 TPUCPUCUDA,即可開始使用 XLA!或者,讓 PJRT 根據您的環境自動選取裝置。

  • 效能提升:減少 gRPC 的額外負荷意味著更快的端對端執行。在 TorchBench 2.0 上,我們觀察到 TPU v4 上的訓練時間提升了 >35%。

  • 輕鬆的 Pod 執行:只需將您的程式碼複製到每個 TPU 工作站,然後使用 gcloud compute tpus tpuvm ssh --worker=all 同時執行它們。

  • 更佳的擴展性:移除 XRT 對參數大小的限制,並支援高達 2048 個 TPU 晶片。

快速入門

若要開始搭配 PyTorch/XLA 使用 PJRT,您只需設定 PJRT_DEVICE 環境變數即可。如果您正在使用 TPU v2 或 v3,請繼續閱讀以了解 TPU v2 和 v3 與 v4 之間的差異。

CPU

在任何安裝 PyTorch/XLA 的機器上,您都可以像這樣在 CPU 上執行我們的 MNIST 範例

PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data

TPU

若要建立安裝 PyTorch/XLA r2.0 的新 TPU

gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT

在 v4-8 上,您可以像這樣執行我們的 ResNet50 範例

git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

預設情況下,PJRT 將使用所有 TPU 晶片。若要僅使用一個 TPU 晶片,請設定 TPU_PROCESS_BOUNDSTPU_VISIBLE_CHIPS

TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

Pods

在 TPU Pods 上,使用 gcloud 在每個 TPU 上平行執行您的命令

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"

Docker

您也可以使用 Docker 在預先安裝 PyTorch/XLA 的容器中執行您的工作負載

export DOCKER_IMAGE=gcr.io/...

# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"

# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"

請注意,docker run 需要對主機的特殊權限存取 (--privileged) 才能將 TPU 裝置公開給容器。目前僅在使用主機網路 --net=host 的情況下,才支援 TPU Pods 上的 Docker。如需更多資訊,請參閱 Cloud TPU 文件

GPU

單節點 GPU 訓練

若要搭配 PJRT 使用 GPU,只需設定 PJRT_DEVICE=CUDA 並設定 GPU_NUM_DEVICES 為主機上的裝置數量。例如

PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

您也可以使用 torchrun 來啟動單節點多 GPU 訓練。例如,

PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在上述範例中,--nnodes 表示要使用多少部機器 (實體機器或 VM) (由於我們執行單節點訓練,因此為 1)。--nproc-per-node 表示要使用多少個 GPU 裝置。

多節點 GPU 訓練

請注意,此功能僅適用於 cuda 12+。與 PyTorch 使用多節點訓練類似,您可以如下所示執行命令

PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
  • --nnodes:要使用多少部 GPU 機器。

  • --node_rank:目前 GPU 機器的索引。值可以是 0、1、…、${NUMBER_GPU_VM}-1。

  • --nproc_per_node:要在目前機器上使用的 GPU 裝置數量。

  • --rdzv_endpoint:node_rank==0 的 GPU 機器的端點,格式為 host:porthost 將會是內部 IP 位址。port 可以是機器上的任何可用連接埠。對於單節點訓練/推論,可以省略此參數。

例如,如果您想要在 2 部 GPU 機器上進行訓練:機器 0 和機器 1,請在第一部 GPU 機器機器 0 上執行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在第二部 GPU 機器上執行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

上述 2 個命令之間的差異在於 --node_rank,以及如果您想要在每部機器上使用不同數量的 GPU 裝置,則可能還有 --nproc_per_node。其餘所有內容都相同。如需有關 torchrun 的更多資訊,請參閱此頁面

與 XRT 的差異

雖然在大多數情況下,我們預期 PJRT 和 XRT 從終端使用者的角度來看,在很大程度上是可互換的 (尤其是在 TPU v4 上),但仍有一些細微的差異需要牢記在心。重要的是,XRT 是圍繞 TPU 節點架構設計的,因此即使在 TPU VM 上,它也始終會產生用戶端和伺服器進程。因此,每個輸入批次都有來自序列化和還原序列化資料以透過網路傳送資料的額外延遲。

PJRT 直接使用本機裝置,而沒有中間伺服器進程。在預設組態中,PJRT 將為每個 TPU 晶片建立一個進程,或每個 TPU 主機 4 個進程。如需有關 TPU 架構的更多資訊,請參閱 Cloud TPU 文件

  • 對於受額外負荷限制的工作負載,效能提升是可能的。

  • 在 XRT 下,伺服器進程是唯一與 TPU 裝置互動的進程,而用戶端進程無法直接存取 TPU 裝置。當分析單主機 TPU (例如 v3-8 或 v4-8) 時,您通常會看到 8 個裝置追蹤 (每個 TPU 核心一個)。使用 PJRT,每個進程都有一個晶片,而來自該進程的分析檔將僅顯示 2 個 TPU 核心。

    • 出於相同原因,分析在具有 XRT 的 TPU Pods 上不起作用,因為伺服器進程獨立於使用者的模型程式碼執行。PJRT 沒有該限制,因此可以在 TPU Pod 中為每個進程分析 2 個 TPU 核心。

  • PJRT 僅支援 TPU VM 架構,我們沒有計畫支援具有 PJRT 的 TPU 節點架構。

  • 使用 PJRT,執行階段組態明顯更簡單。xla_dist 不是執行 TPU Pod 工作負載的必要條件。而是將您的程式碼複製到每個 TPU 主機 ([gcloud compute tpus tpu-vm   scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)),並在每個主機上平行執行程式碼 (例如 [gcloud compute tpus tpu-vm   ssh --workers=all --command="PJRT_DEVICE=TPU python   run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh))

  • xm.rendezvous 已使用 XLA 原生集體通訊重新實作,以增強大型 TPU Pods 的穩定性。請參閱下方以取得更多詳細資訊。

TPU v2/v3 上的多執行緒

在 TPU v2 和 v3 上,分散式工作負載始終以多執行緒方式執行,因為每個 TPU 核心將兩個 TPU 核心公開為裝置,並且一次只能有一個進程開啟一個 TPU 晶片。在其預設組態中,xmp.spawn 會自動產生盡可能多的進程 (每個 TPU 主機 4 個),並為每個進程建立兩個執行緒 (每個 TPU 核心一個)。

注意:在 TPU v4 上,每個 TPU 晶片都表示為一個 PyTorch 裝置,因此分散式工作負載將跨 4 個進程執行,每個進程只有一個執行緒。這與 XRT 的行為相同。

在大多數情況下,這不需要對您現有的程式碼進行實質性變更。在大多數情況下,您必須進行的主要變更是模型初始化。由於 torch 的全域 RNG 在執行緒之間共用,因此即使您在每個複本中將 torch.manual_seed 設定為相同的值,執行緒和執行之間的結果也會有所不同。若要在複本之間取得一致的參數,請使用 torch_xla.experimental.pjrt.broadcast_master_param 將一個複本的參數廣播到所有其他複本,或從通用檢查點載入每個複本的參數。

xm.rendezvous 的變更

PyTorch/XLA r2.0 中的新功能

使用 XRT,工作站 0 執行網格主機服務,並且所有工作站上的所有進程都透過 gRPC 連線至該服務。實際上,我們發現由於連線至工作站 0 的輸入連線數量,在具有數千個晶片的 TPU Pods 上執行單一網格主機進程是不可靠的。單一用戶端進程逾時可能會導致失敗,並強制整個工作負載重新啟動。

因此,我們已使用原生 XLA 集體通訊重新實作 xm.rendezvous,這在大型 TPU Pods 上更加穩定且經過充分測試。與 XRT 實作相比,這施加了兩個新的限制

  • 由於酬載必須成為 XLA 圖表的一部分,因此會在資料傳輸之前和之後都呼叫 xm.mark_step。在模型程式碼中間呼叫 xm.rendezvous 可能會強制進行不必要的編譯。

  • 由於 XLA 不允許集體運算在工作站的子集上執行,因此所有工作站都必須參與 rendezvous

如果您需要 xm.rendezvous 的舊行為 (即,在不變更 XLA 圖表和/或同步工作站子集的情況下傳輸資料),請考慮使用 `torch.distributed.barrier <https://pytorch.dev.org.tw/docs/stable/distributed.html#torch.distributed.barrier&gt[__ 或 ]{.title-ref}torch.distributed.all_gather_object <https://pytorch.dev.org.tw/docs/stable/distributed.html#torch.distributed.all_gather_object&gt[__ 搭配 ]{.title-ref}[gloo]{.title-ref}[ 進程群組。如果您也使用 ]{.title-ref}[xla]{.title-ref}[ ]{.title-ref}[torch.distributed]{.title-ref}[ 後端,則可以使用 ]{.title-ref}[torch.new*group]{.title-ref}[ 建立 ]{.title-ref}[gloo]{.title-ref}[ 子群組。請參閱 `此範例 https://pytorch.dev.org.tw/docs/stable/distributed.html#monitored-barrier]{.title-ref}*_,取自 PyTorch 文件。請記住以下限制

  • torch.distributed 在 TPU v2/v3 上未完全支援。僅實作了具有 xla 後端的運算子集,並且 gloo 在多執行緒內容中可能無法如預期般運作。

  • 在我們的實驗中,gloo 無法很好地擴展到數千個 TPU 晶片,因此預期此替代方案在大型規模下不如使用 PJRT 的 xm.rendezvous 可靠。

PJRT 和 torch.distributed

PyTorch/XLA r2.0 中的新功能

當搭配 torch.distributed[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md) 使用 PJRT 時,我們強烈建議使用新的 xla:// init_method,它會透過查詢執行階段自動尋找複本 ID、世界大小和主機 IP。例如

import torch
import torch_xla
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend

def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)

  xm.mark_step()
  print(output)

if __name__ == '__main__':
  torch_xla.launch(_all_gather)

注意:雖然在 TPU v4 上不需要 xla:// init_method,但仍建議使用。如果您使用 env://,則 MASTER_ADDR 必須設定為具有裝置 0 的 IP 主機,該主機一定是工作站 0。xla:// init_method 會自動尋找此 IP。

注意:對於 TPU v2/v3,您仍然需要匯入 torch_xla.experimental.pjrt_backend,因為 torch.distributed 中的 TPU v2/v3 支援仍處於實驗階段。

如需有關在 PyTorch/XLA 上使用 DistributedDataParallel 的更多資訊,請參閱 TPU V4 上的 ddp.md。如需同時使用 DDP 和 PJRT 的範例,請在 TPU 上執行以下 範例指令碼

PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1

效能

TorchBench 顯示,與 XRT 相比,PJRT 在各項任務中的平均訓練時間有所改善,TPU v4-8 的平均改善幅度超過 35%。好處因任務和模型類型而異,範圍從 0% 到 175%。下圖顯示了按任務劃分的細目

PJRT vs XRT

新的 TPU 執行階段

PyTorch/XLA r2.0 中的新功能

PyTorch/XLA r2.0 版本引入了對 PJRT 外掛程式 API 的支援,該 API 用於存取 libtpu 中新的基於 TFRT 的 TPU 執行階段。現在,當設定 PJRT_DEVICE=TPU 時,這是預設執行階段。在 1.13 中使用的舊版基於 StreamExecutor 的 TPU 執行階段在 2.0 版本中仍然可以使用 PJRT_DEVICE=TPU_LEGACY,但在未來版本中將會移除。如果您遇到僅在 TPU 上發生而不是 TPU_LEGACY 上發生的問題,請在 GitHub 上提交問題。

在大多數情況下,我們預期這兩個執行階段之間的效能相似,但在某些情況下,新的執行階段可能會快達 30%。下圖顯示了按任務劃分的細目

TFRT vs StreamExecutor

注意:此圖表中顯示的改進也包含在 PJRT 與 XRT 的比較中。

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源