• 文件 >
  • PyTorch/XLA SPMD 使用者指南
捷徑

PyTorch/XLA SPMD 使用者指南

在本使用者指南中,我們將討論 GSPMD 如何整合到 PyTorch/XLA 中,並提供設計概觀,以說明 SPMD 分片註釋 API 及其建構方式的運作原理。

什麼是 PyTorch/XLA SPMD?

GSPMD 是用於常見 ML 工作負載的自動並行化系統。XLA 編譯器將根據使用者提供的分片提示,將單一裝置程式轉換為具有適當集合運算的分割程式。此功能讓開發人員能夠編寫 PyTorch 程式,如同在單一大型裝置上,而無需任何自訂分片運算元和/或集合通訊來擴展。

Execution strategies

*圖 1. 兩種不同執行策略的比較,(a) 用於非 SPMD,(b) 用於 SPMD。*

如何使用 PyTorch/XLA SPMD?

以下是使用 SPMD 的簡單範例

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh


# Enable XLA SPMD execution mode.
xr.use_spmd()


# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))


t = torch.randn(8, 4).to(xm.xla_device())


# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)

讓我們逐一解釋這些概念

SPMD 模式

為了使用 SPMD,您需要透過 xr.use_spmd() 啟用它。在 SPMD 模式下,只有一個邏輯裝置。分散式運算和集合運算由 mark_sharding 處理。請注意,使用者不能將 SPMD 與其他分散式函式庫混合使用。

網格 (Mesh)

對於給定的裝置叢集,物理網格是互連拓撲的表示。

  1. mesh_shape 是一個元組,將乘以物理裝置總數。

  2. device_ids 幾乎總是 np.array(range(num_devices))

  3. 也鼓勵使用者為每個網格維度命名。在上面的範例中,第一個網格維度是 data 維度,第二個網格維度是 model 維度。

您也可以透過以下方式檢查更多網格資訊

>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])

分割規格 (Partition Spec)

partition_spec 與輸入張量具有相同的秩。每個維度描述了對應的輸入張量維度如何在裝置網格上分片。在上面的範例中,張量 t 的第一個維度在 data 維度上分片,第二個維度在 model 維度上分片。

使用者也可以對具有與網格形狀不同維度的張量進行分片。

t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)

# First dimension is being replicated.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))

# First dimension is being sharded at data dimension.
# model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))

# First dimension is sharded across both mesh axes.
xs.mark_sharding( t2, mesh, (('data', 'model'),))

延伸閱讀

  1. 範例,使用 SPMD 來表達資料並行性。

  2. 範例,使用 SPMD 來表達 FSDP(完全分片資料並行)。

  3. SPMD 進階主題

  4. Spmd 分散式檢查點

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源