• 教學 >
  • 使用 DDP 進行多 GPU 訓練
捷徑

簡介 || 什麼是 DDP || 單節點多 GPU 訓練 || 容錯 || 多節點訓練 || minGPT 訓練

使用 DDP 進行多 GPU 訓練

建立於:2022 年 9 月 27 日 | 最後更新:2024 年 11 月 03 日 | 最後驗證:未驗證

作者:Suraj Subramanian

您將學到什麼
  • 如何透過 DDP 將單 GPU 訓練腳本遷移到多 GPU

  • 設定分散式處理程序群組

  • 在分散式設定中儲存和載入模型

GitHub 上檢視本教學中使用的程式碼

先決條件
  • DDP 如何運作 的高階概觀

  • 具有多個 GPU 的機器(本教學使用 AWS p3.8xlarge 實例)

  • 安裝並具有 CUDA 的 PyTorch

觀看以下影片或在 youtube 上觀看。

先前的教學中,我們對 DDP 的運作方式進行了高階概觀;現在我們看看如何在程式碼中使用 DDP。在本教學中,我們從單 GPU 訓練腳本開始,並將其遷移到在單個節點上的 4 個 GPU 上執行。在此過程中,我們將討論分散式訓練中的重要概念,同時在我們的程式碼中實作它們。

注意

如果您的模型包含任何 BatchNorm 層,則需要將其轉換為 SyncBatchNorm,以同步所有副本的 BatchNorm 層的運行統計資料。

使用輔助函數 torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 將模型中的所有 BatchNorm 層轉換為 SyncBatchNorm

single_gpu.py v/s multigpu.py 的差異

這些是您通常對單 GPU 訓練腳本所做的更改,以啟用 DDP。

匯入

  • torch.multiprocessing 是 PyTorch 對 Python 原生多處理的封裝

  • 分散式處理程序群組包含所有可以相互通信和同步的處理程序。

import torch
import torch.nn.functional as F
from utils import MyTrainDataset

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

建構處理程序群組

  • 首先,在初始化群組處理程序之前,呼叫 set_device,它會為每個處理程序設定預設 GPU。這對於防止 GPU:0 上的掛起或過度記憶體使用至關重要

  • 處理程序群組可以透過 TCP (預設) 或從共用檔案系統初始化。閱讀更多關於處理程序群組初始化

  • init_process_group 初始化分散式處理程序群組。

  • 閱讀更多關於選擇 DDP 後端

def ddp_setup(rank: int, world_size: int):
   """
   Args:
       rank: Unique identifier of each process
      world_size: Total number of processes
   """
   os.environ["MASTER_ADDR"] = "localhost"
   os.environ["MASTER_PORT"] = "12355"
   torch.cuda.set_device(rank)
   init_process_group(backend="nccl", rank=rank, world_size=world_size)

建構 DDP 模型

self.model = DDP(model, device_ids=[gpu_id])

分配輸入資料

  • DistributedSampler 將輸入資料分塊到所有分散式處理程序中。

  • DataLoader 結合了資料集和

    取樣器,並提供對給定資料集的可迭代物件。

  • 每個進程將接收一個包含 32 個樣本的輸入批次;有效的批次大小為 32 * nprocs,或者在使用 4 個 GPU 時為 128。

train_data = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,  # We don't shuffle
    sampler=DistributedSampler(train_dataset), # Use the Distributed Sampler here.
)
  • 在每個 epoch 的開始時,調用 DistributedSampler 上的 set_epoch() 方法是必要的,以確保在多個 epoch 之間正確地進行 shuffle。否則,每個 epoch 都會使用相同的排序。

def _run_epoch(self, epoch):
    b_sz = len(next(iter(self.train_data))[0])
    self.train_data.sampler.set_epoch(epoch)   # call this additional line at every epoch
    for source, targets in self.train_data:
      ...
      self._run_batch(source, targets)

儲存模型檢查點

  • 我們只需要從一個進程儲存模型檢查點。如果沒有這個條件,每個進程都會儲存一個相同的模型副本。有關使用 DDP 儲存和載入模型的更多資訊,請參閱這裡

- ckp = self.model.state_dict()
+ ckp = self.model.module.state_dict()
...
...
- if epoch % self.save_every == 0:
+ if self.gpu_id == 0 and epoch % self.save_every == 0:
  self._save_checkpoint(epoch)

警告

Collective calls 是在所有分散式進程上執行的函數,它們用於將某些狀態或值收集到特定進程。 Collective calls 要求所有 rank 都執行 collective code。 在這個例子中, _save_checkpoint 不應該有任何 collective calls,因為它只在 rank:0 進程上執行。 如果您需要進行任何 collective calls,它應該在 if self.gpu_id == 0 檢查之前。

執行分散式訓練任務

  • 包含新的參數 rank (取代 device) 和 world_size

  • rank 在呼叫 mp.spawn 時由 DDP 自動分配。

  • world_size 是訓練任務中進程的數量。 對於 GPU 訓練,這對應於使用的 GPU 數量,每個進程都在專用的 GPU 上工作。

- def main(device, total_epochs, save_every):
+ def main(rank, world_size, total_epochs, save_every):
+  ddp_setup(rank, world_size)
   dataset, model, optimizer = load_train_objs()
   train_data = prepare_dataloader(dataset, batch_size=32)
-  trainer = Trainer(model, train_data, optimizer, device, save_every)
+  trainer = Trainer(model, train_data, optimizer, rank, save_every)
   trainer.train(total_epochs)
+  destroy_process_group()

if __name__ == "__main__":
   import sys
   total_epochs = int(sys.argv[1])
   save_every = int(sys.argv[2])
-  device = 0      # shorthand for cuda:0
-  main(device, total_epochs, save_every)
+  world_size = torch.cuda.device_count()
+  mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)

以下是程式碼的樣子

延伸閱讀

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源