• 文件 >
  • PyTorch/XLA SPMD 進階主題
快速鍵

PyTorch/XLA SPMD 進階主題

在這份文件中,我們將涵蓋 GSPMD 的一些進階主題。在繼續閱讀本文檔之前,請先閱讀 SPMD 使用者指南

分片感知的主機到裝置資料載入

PyTorch/XLA SPMD 採用單裝置程式,將其分片並平行執行。SPMD 執行需要使用原生 PyTorch DataLoader,它會從主機同步傳輸資料到 XLA 裝置。這會在每個步驟的輸入資料傳輸期間阻礙訓練。為了提升原生資料載入效能,我們讓 PyTorch/XLA ParallelLoader 直接支援輸入分片 (src),當傳遞選用的 kwarg _input_sharding_ 時

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
         train_loader,  # wraps PyTorch DataLoader
         device,
     # assume 4d input and we want to shard at the batch dimension.
         input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))

如果批次中的每個元素形狀不同,也可以為其指定不同的 input_sharding

# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor for shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
         train_loader,  # wraps PyTorch DataLoader
         device,
     # specify different sharding for each input of the batch.
         input_sharding={
          'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
          'y': xs.ShardingSpec(input_mesh, ('data', None))
        }
)

虛擬裝置最佳化

PyTorch/XLA 通常會在張量定義後,從主機非同步傳輸張量資料到裝置。這是為了使資料傳輸與圖形追蹤時間重疊。然而,由於 GSPMD 允許使用者在張量定義_後_修改張量分片,我們需要一種最佳化方法來防止在主機和裝置之間不必要的張量資料來回傳輸。我們引入虛擬裝置最佳化,這是一種將張量資料先放置在虛擬裝置 SPMD:0 上的技術,然後在所有分片決策完成後再上傳到實體裝置。SPMD 模式中的每個張量資料都放置在虛擬裝置 SPMD:0 上。虛擬裝置以 XLA 裝置 XLA:0 的形式向使用者公開,而實際的分片則位於實體裝置上,例如 TPU:0、TPU:1 等。

混合網格

網格很好地抽象化了實體裝置網格的建構方式。使用者可以使用邏輯網格以任何形狀和順序排列裝置。然而,可以根據實體拓撲定義效能更高的網格,尤其是在涉及資料中心網路 (DCN) 跨切片連線時。HybridMesh 建立了一個網格,可以在此類多切片環境中提供良好的開箱即用效能。它接受 ici_mesh_shape 和 dcn_mesh_shape,它們表示內部和外部網路的邏輯網格形狀。

from torch_xla.distributed.spmd import HybridMesh

# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
# - dcn_mesh_shape: shape of logical mesh for outer connected devices.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)

mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

在 TPU Pod 上執行 SPMD

如果您根據裝置數量而不是某些硬式編碼常數來建構網格和分割規格,則從單一 TPU 主機移轉到 TPU Pod 不需要任何程式碼變更。若要在 TPU Pod 上執行 PyTorch/XLA 工作負載,請參閱我們 PJRT 指南的 Pods 章節

XLAShardedTensor

xs.mark_sharding 是一個原地 (inplace) 運算,它會將分片註解附加到輸入張量,但它也會傳回一個 XLAShardedTensor python 物件。

XLAShardedTensor 的主要用例 [RFC] 是使用分片規格註解原生 torch.tensor (在單一裝置上)。註解會立即發生,但張量的實際分片會延遲,因為計算是以延遲方式執行的,輸入張量除外,輸入張量會立即分片。一旦張量被註解並包裝在 XLAShardedTensor 中,就可以將其作為 torch.Tensor 傳遞給現有的 PyTorch 運算和 nn.Module 層。這對於確保相同的 PyTorch 層和張量運算可以與 XLAShardedTensor 堆疊在一起非常重要。這表示使用者不需要為分片運算重寫現有的運算和模型程式碼。也就是說,XLAShardedTensor 將滿足以下要求

  • XLAShardedTensortorch.Tensor 的子類別,可直接與原生 torch 運算和 module.layers 搭配使用。我們使用 __torch_dispatch__XLAShardedTensor 傳送到 XLA 後端。PyTorch/XLA 擷取附加的分片註解以追蹤圖形,並調用 XLA SPMDPartitioner。

  • 在內部,XLAShardedTensor (及其 global_tensor 輸入) 由 XLATensor 支援,XLATensor 具有特殊的資料結構,用於保存對分片裝置資料的參考。

  • 延遲執行後的分片張量可能會在主機上請求時 (例如,列印全域張量的值) 收集並實體化回主機作為 global_tensor。

  • 本機分片的控制代碼嚴格在延遲執行後才實體化。XLAShardedTensor 公開 local_shards 以傳回可定址裝置上的本機分片,格式為 List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]

目前也正在努力將 XLAShardedTensor 整合到 DistributedTensor API 中,以支援 XLA 後端 [RFC]。

DTensor 整合

PyTorch 在 2.1 中原型發布了 DTensor。我們正在將 PyTorch/XLA SPMD 整合到 DTensor API RFC 中。我們針對 distribute_tensor 進行了概念驗證整合,它會呼叫 mark_sharding 註解 API,以使用 XLA 分片張量及其計算

import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor

# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])

此功能為實驗性功能,請隨時關注即將發布版本中的更多更新、範例和教學。

torch.compile 的啟動分片

在 2.3 版本中,PyTorch/XLA 新增了自訂運算 dynamo_mark_sharding,可用於在 torch.compile 區域中執行啟動分片。這是我們持續努力的一部分,旨在使 torch.compile + GSPMD 成為使用 PyTorch/XLA 進行模型推論的建議方式。以下是使用此自訂運算的範例

# Activation output sharding
device_ids = [i for i in range(self.num_devices)] # List[int]
mesh_shape = [self.num_devices//2, 1, 2] # List[int]
axis_names = "('data', 'model')" # string version of axis_names
partition_spec = "('data', 'model')" # string version of partition spec
torch.ops.xla.dynamo_mark_sharding(output, device_ids, mesh_shape, axis_names, partition_spec)

SPMD 偵錯工具

我們為 PyTorch/XLA SPMD 使用者在 TPU/GPU/CPU 上,針對單主機/多主機提供了 shard placement visualization debug tool:您可以使用 visualize_tensor_sharding 來視覺化分片張量,或者您可以使用 visualize_sharding 來視覺化分片字串。以下是在 TPU 單主機 (v4-8) 上使用 visualize_tensor_shardingvisualize_sharding 的兩個程式碼範例

  • 使用的程式碼片段 visualize_tensor_sharding 和視覺化結果

import rich

# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))

# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)
visualize_tensor_sharding example on TPU v4-8(single-host)
  • 使用的程式碼片段 visualize_sharding 和視覺化結果

from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
visualize_sharding example on TPU v4-8(single-host)

您可以在 TPU/GPU/CPU 單主機上使用這些範例,並修改它們以在多主機上執行。而且您可以將其修改為分片樣式 tiledpartial_replicationreplicated

自動分片

我們正在推出一個新的 PyTorch/XLA SPMD 功能,稱為 auto-shardingRFC。這是 r2.3nightly 中的實驗性功能,支援 XLA:TPU 和單一 TPUVM 主機。

可以透過以下方式之一啟用 PyTorch/XLA 自動分片

  • 設定環境變數 XLA_AUTO_SPMD=1

  • 在程式碼開頭呼叫 SPMD API

import torch_xla.runtime as xr
xr.use_spmd(auto=True)
  • 使用 auto-policyxla 呼叫 pytorch.distributed._tensor.distribute_module

import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, model should be loaded to xla device via distribute_module.
model = MyModule()  # nn.module
sharded_model = distribute_module(model, device_mesh, auto_policy)

或者,可以設定以下選項/環境變數來控制基於 XLA 的自動分片傳遞的行為

  • XLA_AUTO_USE_GROUP_SHARDING:參數的群組重新分片。預設設定。

  • XLA_AUTO_SPMD_MESH:用於自動分片的邏輯網格形狀。例如,XLA_AUTO_SPMD_MESH=2,2 對應於具有 4 個全域裝置的 2x2 網格。如果未設定,將使用預設裝置網格形狀 num_devices,1

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源