捷徑

訓練腳本

如果您的訓練腳本與 torch.distributed.launch 搭配使用,它將繼續與 torchrun 搭配使用,但有以下差異

  1. 無需手動傳遞 RANKWORLD_SIZEMASTER_ADDRMASTER_PORT

  2. 可以提供 rdzv_backendrdzv_endpoint。 對於大多數使用者來說,這將設定為 c10d (請參閱 rendezvous)。 預設的 rdzv_backend 建立一個非彈性的 rendezvous,其中 rdzv_endpoint 保留 master 位址。

  3. 請確保您的腳本中有 load_checkpoint(path)save_checkpoint(path) 的邏輯。 當任何數量的 worker 失敗時,我們會使用相同的程式引數重新啟動所有 worker,因此您最多會遺失到最近檢查點的進度 (請參閱彈性啟動)。

  4. use_env 標誌已被移除。 如果您之前是通過解析 --local-rank 選項來解析 local rank,則需要從環境變數 LOCAL_RANK 獲取 local rank (例如 int(os.environ["LOCAL_RANK"]))。

下面是一個示範性的訓練腳本範例,該腳本在每個 epoch 進行檢查點儲存,因此在發生故障時,最糟情況下損失的進度是一個完整 epoch 的訓練。

def main():
     args = parse_args(sys.argv[1:])
     state = load_checkpoint(args.checkpoint_path)
     initialize(state)

     # torch.distributed.run ensures that this will work
     # by exporting all the env vars needed to initialize the process group
     torch.distributed.init_process_group(backend=args.backend)

     for i in range(state.epoch, state.total_num_epochs)
          for batch in iter(state.dataset)
              train(batch, state.model)

          state.epoch += 1
          save_checkpoint(state)

有關符合 torchelastic 規範的訓練腳本的具體範例,請訪問我們的範例頁面。

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources