• 文件 >
  • 通用 Join Context Manager
快捷鍵

通用 Join Context Manager

通用 join context manager 有助於在不均勻的輸入上進行分散式訓練。 本頁概述了相關類別的 API:JoinJoinableJoinHook。 如需教學課程,請參閱使用 Join Context Manager 進行具有不均勻輸入的分散式訓練

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[source][source]

這個類別定義了通用的 join context manager,允許在程序加入 (join) 後呼叫自定義的 hook。

這些 hook 應該遮蔽 (shadow) 未加入程序的 collective communications,以防止掛起和出錯,並確保演算法的正確性。有關 hook 定義的詳細信息,請參閱 JoinHook

警告

此 context manager 需要每個參與的 Joinable 在其自身每次迭代的 collective communications 之前呼叫方法 notify_join_context(),以確保正確性。

警告

此 context manager 要求 JoinHook 物件中的所有 process_group 屬性都是相同的。 如果有多個 JoinHook 物件,則使用第一個物件的 device。 處理群組和裝置資訊用於檢查未加入的程序,以及通知程序在啟用 throw_on_early_termination 時拋出例外,這兩者都使用 all-reduce。

參數
  • joinables (List[Joinable]) – 參與的 Joinable 物件的列表;它們的 hook 按給定的順序迭代。

  • enable (bool) – 一個標誌,用於啟用不均勻輸入的偵測;設定為 False 會停用 context manager 的功能,只有在使用者知道輸入不會不均勻時才應設定 (預設值:True)。

  • throw_on_early_termination (bool) – 一個標誌,控制在偵測到不均勻輸入時是否拋出例外 (預設值:False)。

範例

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring
static notify_join_context(joinable)[source][source]

通知 join context manager 呼叫程序尚未加入 (join)。

然後,如果 throw_on_early_termination=True,檢查是否已偵測到不均勻輸入 (即如果一個程序已經加入),如果是,則拋出例外。

此方法應從 Joinable 物件的每次迭代的 collective communications 之前呼叫。 例如,這應該在 DistributedDataParallel 中 forward pass 的開頭呼叫。

只有傳遞到 context manager 的第一個 Joinable 物件會在此方法中執行 collective communications,對於其他物件,此方法是空洞的 (vacuous)。

參數

joinable (Joinable) – 呼叫此方法的 Joinable 物件。

返回

如果 joinable 是傳遞到 context manager 的第一個物件,則返回一個 async work handle,用於通知 context manager 該程序尚未加入;否則返回 None

class torch.distributed.algorithms.Joinable[source][source]

這定義了 joinable 類別的抽象基底類別。

一個 joinable 類別(繼承自 Joinable)應該實現 join_hook(),它返回一個 JoinHook 實例,此外還需實現 join_device()join_process_group(),它們分別返回裝置和處理群組資訊。

abstract property join_device: device

傳回用於執行連接上下文管理器所需的集體通訊的裝置。

abstract join_hook(**kwargs)[source][source]

針對給定的 Joinable,傳回 JoinHook 實例。

參數

kwargs (dict) – 一個 dict,包含用於修改執行時連接鉤子行為的任何關鍵字引數;所有共享相同連接上下文管理器的 Joinable 實例都會被傳遞相同的 kwargs 值。

回傳型別

JoinHook

abstract property join_process_group: Any

傳回連接上下文管理器本身所需的集體通訊的進程組。

class torch.distributed.algorithms.JoinHook[source][source]

這定義了一個連接鉤子,它在連接上下文管理器中提供兩個入口點。

入口點:一個主鉤子,在存在未連接進程時重複呼叫,以及一個後鉤子,在所有進程都已連接後呼叫一次。

要為通用連接上下文管理器實現連接鉤子,請定義一個繼承自 JoinHook 的類別,並根據需要覆寫 main_hook()post_hook()

main_hook()[source][source]

當存在未連接的進程時,呼叫此鉤子以遮蔽訓練迭代中的集體通訊。

訓練迭代,即在一次正向傳遞、反向傳遞和優化器步驟中。

post_hook(is_last_joiner)[source][source]

在所有進程都已連接後呼叫鉤子。

它會傳遞一個額外的 bool 引數 is_last_joiner,用於指示該 rank 是否為最後加入的其中之一。

參數

is_last_joiner (bool) – 如果該 rank 是最後加入的其中之一,則為 True;否則為 False

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources