簡介 || 什麼是 DDP || 單節點多 GPU 訓練 || 容錯 || 多節點訓練 || minGPT 訓練
使用 torchrun
進行容錯分散式訓練¶
建立於:2022 年 9 月 27 日 | 最後更新:2024 年 11 月 12 日 | 最後驗證:2024 年 11 月 05 日
請觀看以下影片或在 youtube 上觀看。
在分散式訓練中,單一進程故障可能會中斷整個訓練任務。由於此處發生故障的可能性較高,因此使您的訓練腳本具有強大的容錯性尤其重要。您可能還希望您的訓練任務是彈性的,例如,計算資源可以在任務過程中動態地加入和離開。
PyTorch 提供一個稱為 torchrun
的工具,它提供容錯和彈性訓練。發生故障時,torchrun
會記錄錯誤並嘗試從訓練任務上次儲存的「快照」自動重新啟動所有進程。
快照儲存的不僅僅是模型狀態;它還可以包括有關已運行的 epoch 數量、優化器狀態或訓練任務的任何其他有狀態屬性的詳細資訊,這些資訊對於其連續性是必要的。
為什麼使用 torchrun
¶
torchrun
處理分散式訓練的細節,因此您無需這樣做。例如,
您無需設定環境變數或明確傳遞
rank
和world_size
;torchrun
會將其與其他幾個環境變數一起分配。無需在您的腳本中呼叫
mp.spawn
;您只需要一個通用的main()
進入點,並使用torchrun
啟動腳本。這樣,相同的腳本可以在非分散式以及單節點和多節點設定中運行。從上次儲存的訓練快照優雅地重新啟動訓練。
優雅的重新啟動¶
為了實現優雅的重新啟動,您應該像這樣建構您的訓練腳本:
def main():
load_snapshot(snapshot_path)
initialize()
train()
def train():
for batch in iter(dataset):
train_step(batch)
if should_checkpoint:
save_snapshot(snapshot_path)
如果發生故障,torchrun
將終止所有進程並重新啟動它們。每個進程進入點首先載入並初始化上次儲存的快照,然後從那裡繼續訓練。因此,在任何故障時,您只會失去上次儲存的快照中的訓練進度。
在彈性訓練中,每當有任何成員資格變更(新增或移除節點)時,torchrun
將終止並在可用裝置上產生進程。擁有此結構可確保您的訓練任務可以繼續,而無需手動干預。
multigpu.py v/s multigpu_torchrun.py 的差異
進程群組初始化¶
torchrun
會自動指定RANK
和WORLD_SIZE
,以及其他環境變數。
- def ddp_setup(rank, world_size):
+ def ddp_setup():
- """
- Args:
- rank: Unique identifier of each process
- world_size: Total number of processes
- """
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = "12355"
- init_process_group(backend="nccl", rank=rank, world_size=world_size)
+ init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
使用 torchrun 提供的環境變數¶
- self.gpu_id = gpu_id
+ self.gpu_id = int(os.environ["LOCAL_RANK"])
儲存和載入快照¶
定期將所有相關資訊儲存在快照中,可讓我們的訓練作業在發生中斷後無縫恢復。
+ def _save_snapshot(self, epoch):
+ snapshot = {}
+ snapshot["MODEL_STATE"] = self.model.module.state_dict()
+ snapshot["EPOCHS_RUN"] = epoch
+ torch.save(snapshot, "snapshot.pt")
+ print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")
+ def _load_snapshot(self, snapshot_path):
+ snapshot = torch.load(snapshot_path)
+ self.model.load_state_dict(snapshot["MODEL_STATE"])
+ self.epochs_run = snapshot["EPOCHS_RUN"]
+ print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
在 Trainer 建構函式中載入快照¶
重新啟動中斷的訓練作業時,您的腳本會先嘗試載入快照,以從中恢復訓練。
class Trainer:
def __init__(self, snapshot_path, ...):
...
+ if os.path.exists(snapshot_path):
+ self._load_snapshot(snapshot_path)
...
恢復訓練¶
訓練可以從上次運行的 epoch 恢復,而不是從頭開始。
def train(self, max_epochs: int):
- for epoch in range(max_epochs):
+ for epoch in range(self.epochs_run, max_epochs):
self._run_epoch(epoch)
執行腳本¶
只需像對待非多處理腳本一樣呼叫您的進入點函式;torchrun
會自動產生進程。
if __name__ == "__main__":
import sys
total_epochs = int(sys.argv[1])
save_every = int(sys.argv[2])
- world_size = torch.cuda.device_count()
- mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
+ main(save_every, total_epochs)
- python multigpu.py 50 10
+ torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10
延伸閱讀¶
使用 DDP 進行多節點訓練 (本系列中的下一個教學課程)
使用 DDP 進行多 GPU 訓練 (本系列中的上一個教學課程)