• 教學課程 >
  • 使用 torchrun 進行容錯分散式訓練
快捷方式

簡介 || 什麼是 DDP || 單節點多 GPU 訓練 || 容錯 || 多節點訓練 || minGPT 訓練

使用 torchrun 進行容錯分散式訓練

建立於:2022 年 9 月 27 日 | 最後更新:2024 年 11 月 12 日 | 最後驗證:2024 年 11 月 05 日

作者:Suraj Subramanian

您將學到什麼
  • 使用 torchrun 啟動多 GPU 訓練任務

  • 儲存和載入您的訓練任務的快照

  • 建構您的訓練腳本以進行優雅的重新啟動

GitHub 上檢視本教學中使用的程式碼

先決條件
  • DDP 的高階概述

  • 熟悉 DDP 程式碼

  • 具有多個 GPU 的機器(本教學使用 AWS p3.8xlarge 執行個體)

  • 安裝並帶有 CUDA 的 PyTorch

請觀看以下影片或在 youtube 上觀看。

在分散式訓練中,單一進程故障可能會中斷整個訓練任務。由於此處發生故障的可能性較高,因此使您的訓練腳本具有強大的容錯性尤其重要。您可能還希望您的訓練任務是彈性的,例如,計算資源可以在任務過程中動態地加入和離開。

PyTorch 提供一個稱為 torchrun 的工具,它提供容錯和彈性訓練。發生故障時,torchrun 會記錄錯誤並嘗試從訓練任務上次儲存的「快照」自動重新啟動所有進程。

快照儲存的不僅僅是模型狀態;它還可以包括有關已運行的 epoch 數量、優化器狀態或訓練任務的任何其他有狀態屬性的詳細資訊,這些資訊對於其連續性是必要的。

為什麼使用 torchrun

torchrun 處理分散式訓練的細節,因此您無需這樣做。例如,

  • 您無需設定環境變數或明確傳遞 rankworld_sizetorchrun 會將其與其他幾個環境變數一起分配。

  • 無需在您的腳本中呼叫 mp.spawn;您只需要一個通用的 main() 進入點,並使用 torchrun 啟動腳本。這樣,相同的腳本可以在非分散式以及單節點和多節點設定中運行。

  • 從上次儲存的訓練快照優雅地重新啟動訓練。

優雅的重新啟動

為了實現優雅的重新啟動,您應該像這樣建構您的訓練腳本:

def main():
  load_snapshot(snapshot_path)
  initialize()
  train()

def train():
  for batch in iter(dataset):
    train_step(batch)

    if should_checkpoint:
      save_snapshot(snapshot_path)

如果發生故障,torchrun 將終止所有進程並重新啟動它們。每個進程進入點首先載入並初始化上次儲存的快照,然後從那裡繼續訓練。因此,在任何故障時,您只會失去上次儲存的快照中的訓練進度。

在彈性訓練中,每當有任何成員資格變更(新增或移除節點)時,torchrun 將終止並在可用裝置上產生進程。擁有此結構可確保您的訓練任務可以繼續,而無需手動干預。

multigpu.py v/s multigpu_torchrun.py 的差異

進程群組初始化

- def ddp_setup(rank, world_size):
+ def ddp_setup():
-     """
-     Args:
-         rank: Unique identifier of each process
-         world_size: Total number of processes
-     """
-     os.environ["MASTER_ADDR"] = "localhost"
-     os.environ["MASTER_PORT"] = "12355"
-     init_process_group(backend="nccl", rank=rank, world_size=world_size)
+     init_process_group(backend="nccl")
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

使用 torchrun 提供的環境變數

- self.gpu_id = gpu_id
+ self.gpu_id = int(os.environ["LOCAL_RANK"])

儲存和載入快照

定期將所有相關資訊儲存在快照中,可讓我們的訓練作業在發生中斷後無縫恢復。

+ def _save_snapshot(self, epoch):
+     snapshot = {}
+     snapshot["MODEL_STATE"] = self.model.module.state_dict()
+     snapshot["EPOCHS_RUN"] = epoch
+     torch.save(snapshot, "snapshot.pt")
+     print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")

+ def _load_snapshot(self, snapshot_path):
+     snapshot = torch.load(snapshot_path)
+     self.model.load_state_dict(snapshot["MODEL_STATE"])
+     self.epochs_run = snapshot["EPOCHS_RUN"]
+     print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

在 Trainer 建構函式中載入快照

重新啟動中斷的訓練作業時,您的腳本會先嘗試載入快照,以從中恢復訓練。

class Trainer:
   def __init__(self, snapshot_path, ...):
   ...
+  if os.path.exists(snapshot_path):
+     self._load_snapshot(snapshot_path)
   ...

恢復訓練

訓練可以從上次運行的 epoch 恢復,而不是從頭開始。

def train(self, max_epochs: int):
-  for epoch in range(max_epochs):
+  for epoch in range(self.epochs_run, max_epochs):
      self._run_epoch(epoch)

執行腳本

只需像對待非多處理腳本一樣呼叫您的進入點函式;torchrun 會自動產生進程。

if __name__ == "__main__":
   import sys
   total_epochs = int(sys.argv[1])
   save_every = int(sys.argv[2])
-  world_size = torch.cuda.device_count()
-  mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
+  main(save_every, total_epochs)
- python multigpu.py 50 10
+ torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10

延伸閱讀


評價本教學課程

© 版權所有 2024, PyTorch。

使用 Sphinx 構建,使用 themeRead the Docs 提供。

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適合初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源