通用 Join Context Manager¶
通用 join context manager 有助於在不均勻的輸入上進行分散式訓練。 本頁概述了相關類別的 API:Join
、Joinable
和 JoinHook
。 如需教學課程,請參閱使用 Join Context Manager 進行具有不均勻輸入的分散式訓練。
- class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[source][source]¶
這個類別定義了通用的 join context manager,允許在程序加入 (join) 後呼叫自定義的 hook。
這些 hook 應該遮蔽 (shadow) 未加入程序的 collective communications,以防止掛起和出錯,並確保演算法的正確性。有關 hook 定義的詳細信息,請參閱
JoinHook
。警告
此 context manager 需要每個參與的
Joinable
在其自身每次迭代的 collective communications 之前呼叫方法notify_join_context()
,以確保正確性。警告
此 context manager 要求
JoinHook
物件中的所有process_group
屬性都是相同的。 如果有多個JoinHook
物件,則使用第一個物件的device
。 處理群組和裝置資訊用於檢查未加入的程序,以及通知程序在啟用throw_on_early_termination
時拋出例外,這兩者都使用 all-reduce。- 參數
範例
>>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # Rank 1 gets one more input than rank 0 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # All ranks reach here without hanging/erroring
- static notify_join_context(joinable)[source][source]¶
通知 join context manager 呼叫程序尚未加入 (join)。
然後,如果
throw_on_early_termination=True
,檢查是否已偵測到不均勻輸入 (即如果一個程序已經加入),如果是,則拋出例外。此方法應從
Joinable
物件的每次迭代的 collective communications 之前呼叫。 例如,這應該在DistributedDataParallel
中 forward pass 的開頭呼叫。只有傳遞到 context manager 的第一個
Joinable
物件會在此方法中執行 collective communications,對於其他物件,此方法是空洞的 (vacuous)。
- class torch.distributed.algorithms.Joinable[source][source]¶
這定義了 joinable 類別的抽象基底類別。
一個 joinable 類別(繼承自
Joinable
)應該實現join_hook()
,它返回一個JoinHook
實例,此外還需實現join_device()
和join_process_group()
,它們分別返回裝置和處理群組資訊。