簡介 || 什麼是 DDP || 單節點多 GPU 訓練 || 容錯 || 多節點訓練 || minGPT 訓練
使用 DDP 進行多 GPU 訓練¶
建立於:2022 年 9 月 27 日 | 最後更新:2024 年 11 月 03 日 | 最後驗證:未驗證
觀看以下影片或在 youtube 上觀看。
在先前的教學中,我們對 DDP 的運作方式進行了高階概觀;現在我們看看如何在程式碼中使用 DDP。在本教學中,我們從單 GPU 訓練腳本開始,並將其遷移到在單個節點上的 4 個 GPU 上執行。在此過程中,我們將討論分散式訓練中的重要概念,同時在我們的程式碼中實作它們。
注意
如果您的模型包含任何 BatchNorm
層,則需要將其轉換為 SyncBatchNorm
,以同步所有副本的 BatchNorm
層的運行統計資料。
使用輔助函數 torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 將模型中的所有 BatchNorm
層轉換為 SyncBatchNorm
。
single_gpu.py v/s multigpu.py 的差異
這些是您通常對單 GPU 訓練腳本所做的更改,以啟用 DDP。
匯入¶
torch.multiprocessing
是 PyTorch 對 Python 原生多處理的封裝分散式處理程序群組包含所有可以相互通信和同步的處理程序。
import torch
import torch.nn.functional as F
from utils import MyTrainDataset
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
建構處理程序群組¶
首先,在初始化群組處理程序之前,呼叫 set_device,它會為每個處理程序設定預設 GPU。這對於防止 GPU:0 上的掛起或過度記憶體使用至關重要
處理程序群組可以透過 TCP (預設) 或從共用檔案系統初始化。閱讀更多關於處理程序群組初始化
init_process_group 初始化分散式處理程序群組。
閱讀更多關於選擇 DDP 後端
def ddp_setup(rank: int, world_size: int):
"""
Args:
rank: Unique identifier of each process
world_size: Total number of processes
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(rank)
init_process_group(backend="nccl", rank=rank, world_size=world_size)
建構 DDP 模型¶
self.model = DDP(model, device_ids=[gpu_id])
分配輸入資料¶
DistributedSampler 將輸入資料分塊到所有分散式處理程序中。
- DataLoader 結合了資料集和
取樣器,並提供對給定資料集的可迭代物件。
每個進程將接收一個包含 32 個樣本的輸入批次;有效的批次大小為
32 * nprocs
,或者在使用 4 個 GPU 時為 128。
train_data = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=32,
shuffle=False, # We don't shuffle
sampler=DistributedSampler(train_dataset), # Use the Distributed Sampler here.
)
在每個 epoch 的開始時,調用
DistributedSampler
上的set_epoch()
方法是必要的,以確保在多個 epoch 之間正確地進行 shuffle。否則,每個 epoch 都會使用相同的排序。
def _run_epoch(self, epoch):
b_sz = len(next(iter(self.train_data))[0])
self.train_data.sampler.set_epoch(epoch) # call this additional line at every epoch
for source, targets in self.train_data:
...
self._run_batch(source, targets)
儲存模型檢查點¶
我們只需要從一個進程儲存模型檢查點。如果沒有這個條件,每個進程都會儲存一個相同的模型副本。有關使用 DDP 儲存和載入模型的更多資訊,請參閱這裡
- ckp = self.model.state_dict()
+ ckp = self.model.module.state_dict()
...
...
- if epoch % self.save_every == 0:
+ if self.gpu_id == 0 and epoch % self.save_every == 0:
self._save_checkpoint(epoch)
警告
Collective calls 是在所有分散式進程上執行的函數,它們用於將某些狀態或值收集到特定進程。 Collective calls 要求所有 rank 都執行 collective code。 在這個例子中, _save_checkpoint 不應該有任何 collective calls,因為它只在 rank:0
進程上執行。 如果您需要進行任何 collective calls,它應該在 if self.gpu_id == 0
檢查之前。
執行分散式訓練任務¶
包含新的參數
rank
(取代device
) 和world_size
。rank
在呼叫 mp.spawn 時由 DDP 自動分配。world_size
是訓練任務中進程的數量。 對於 GPU 訓練,這對應於使用的 GPU 數量,每個進程都在專用的 GPU 上工作。
- def main(device, total_epochs, save_every):
+ def main(rank, world_size, total_epochs, save_every):
+ ddp_setup(rank, world_size)
dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size=32)
- trainer = Trainer(model, train_data, optimizer, device, save_every)
+ trainer = Trainer(model, train_data, optimizer, rank, save_every)
trainer.train(total_epochs)
+ destroy_process_group()
if __name__ == "__main__":
import sys
total_epochs = int(sys.argv[1])
save_every = int(sys.argv[2])
- device = 0 # shorthand for cuda:0
- main(device, total_epochs, save_every)
+ world_size = torch.cuda.device_count()
+ mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
以下是程式碼的樣子