• 文件 >
  • PyTorch XLA 中的完全分片數據並行
快捷方式

PyTorch XLA 中的完全分片數據並行

PyTorch XLA 中的完全分片數據並行 (FSDP) 是一種用於跨數據並行工作者分片模組參數的實用程式。

使用範例

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也可以單獨分片個別層,並讓外部包裝器處理任何剩餘的參數。

注意事項:XlaFullyShardedDataParallel 類別同時支援 ZeRO-2 優化器 (分片梯度和優化器狀態) 和 ZeRO-3 優化器 (分片參數、梯度和優化器狀態),詳見 https://arxiv.org/abs/1910.02054。ZeRO-3 優化器應透過巢狀 FSDP 與 reshard_after_forward=True 實作。範例請參閱 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py。 * 對於無法放入單一 TPU 記憶體或主機 CPU 記憶體的大型模型,應將子模組建構與內部 FSDP 包裝交錯進行。範例請參閱 FSDPViTModel。提供了一個簡單的包裝器 checkpoint_module (基於 torch_xla.utils.checkpoint.checkpoint,來自 https://github.com/pytorch/xla/pull/3524) 以對給定的 nn.Module 實例執行 梯度檢查點。範例請參閱 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py。子模組自動包裝:除了手動巢狀 FSDP 包裝外,還可以指定 auto_wrap_policy 參數,以自動使用內部 FSDP 包裝子模組。torch_xla.distributed.fsdp.wrap 中的 size_based_auto_wrap_policyauto_wrap_policy 可調用物件的範例,此策略會包裝參數數量大於 1 億的層。torch_xla.distributed.fsdp.wrap 中的 transformer_auto_wrap_policy 是用於類變壓器模型架構的 auto_wrap_policy 可調用物件的範例。

例如,若要自動使用內部 FSDP 包裝所有 torch.nn.Conv2d 子模組,可以使用

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

此外,也可以指定 auto_wrapper_callable 參數,以針對子模組使用自訂可調用包裝器 (預設包裝器僅為 XlaFullyShardedDataParallel 類別本身)。例如,可以使用以下程式碼將梯度檢查點 (即啟用檢查點/重新實體化) 應用於每個自動包裝的子模組。

from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
  • 在逐步執行優化器時,直接呼叫 optimizer.step,而不要呼叫 xm.optimizer_step。後者會減少跨 rank 的梯度,這對於 FSDP 來說是不需要的 (因為參數已經分片)。

  • 在訓練期間儲存模型和優化器檢查點時,每個訓練程序都需要儲存其自己的 (分片) 模型和優化器狀態字典的檢查點 (使用 master_only=False 並在 xm.save 中為每個 rank 設定不同的路徑)。恢復時,需要載入對應 rank 的檢查點。

  • 也請儲存 model.get_shard_metadata() 以及 model.state_dict(),如下所示,並使用 consolidate_sharded_model_checkpoints 將分片模型檢查點縫合在一起成為完整的模型狀態字典。範例請參閱 test/test_train_mp_mnist_fsdp_with_ckpt.py

ckpt = {
    'model': model.state_dict(),
    'shard_metadata': model.get_shard_metadata(),
    'optimizer': optimizer.state_dict(),
}
ckpt_path = f'/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth'
xm.save(ckpt, ckpt_path, master_only=False)
  • 檢查點合併腳本也可以從命令列啟動,如下所示。

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /path/to/your_sharded_checkpoint_files \
  --ckpt_suffix "_rank-*-of-*.pth"

此類別的實作很大程度上受到 fairscale.nn.FullyShardedDataParallelhttps://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html 中的結構啟發,並且主要遵循該結構。fairscale.nn.FullyShardedDataParallel 最大的差異之一是,在 XLA 中,我們沒有明確的參數儲存,因此在這裡我們採用不同的方法來釋放 ZeRO-3 的完整參數。

MNIST 和 ImageNet 上的訓練腳本範例

安裝

FSDP 在 PyTorch/XLA 1.12 發行版和更新的 nightly 版本中提供。請參閱 https://github.com/pytorch/xla#-available-images-and-wheels 以取得安裝指南。

複製 PyTorch/XLA repo

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

在 v3-8 TPU 上訓練 MNIST

它在 2 個 epoch 中獲得約 98.9 的準確度

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

此腳本會在結尾自動測試檢查點合併。您也可以透過以下方式手動合併分片檢查點

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

在 v3-8 TPU 上使用 ResNet-50 訓練 ImageNet

它在 100 個 epoch 中獲得約 75.9 的準確度;將 ImageNet-1k 下載到 /datasets/imagenet-1k

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp

您也可以新增 --use_gradient_checkpointing (需要與 --use_nested_fsdp--auto_wrap_policy 一起使用),以在殘差區塊上應用梯度檢查點。

在 TPU pod 上 (具有 100 億個參數) 的訓練腳本範例

若要訓練無法放入單一 TPU 的大型模型,在建構整個模型以實作 ZeRO-3 演算法時,應使用 auto-wrap 或手動使用內部 FSDP 包裝子模組。

請參閱 https://github.com/ronghanghu/vit_10b_fsdp_example,以取得使用此 XLA FSDP PR 分片訓練 Vision Transformer (ViT) 模型的範例。

文件

存取 PyTorch 的全面開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源