• 教學 >
  • 使用 Join Context Manager 進行具有不均勻輸入的分散式訓練
捷徑

使用 Join Context Manager 進行具有不均勻輸入的分散式訓練

建立於:2021 年 8 月 4 日 | 最後更新:2023 年 1 月 9 日 | 最後驗證:2024 年 11 月 5 日

作者Andrew Gu

注意

editgithub 中查看和編輯本教學。

注意

Join 在 PyTorch 1.10 中作為原型功能引入。此 API 可能會變更。

在本教學中,您將看到

  • Join context manager 的概述。

  • 如何將 context manager 與 DistributedDataParallel 一起使用的範例。

  • 如何將 context manager 與 DistributedDataParallelZeroRedundancyOptimizer 一起使用的範例。

  • 將關鍵字參數傳遞到 context manager 的範例。

  • 深入了解 Join context manager 的運作方式。

  • 展示如何使玩具類別與 context manager 相容的範例。

什麼是 Join

分散式資料平行入門 - 基本使用案例中,您看到了使用 DistributedDataParallel 執行資料平行訓練的一般框架。這隱含地在每個反向傳播中排程 all-reduces,以同步各個 rank 之間的梯度。這種 集體通信需要處理群組中所有 rank 的參與,因此如果一個 rank 的輸入較少,則其他 rank 將會掛起或出錯(取決於後端)。更廣泛地說,對於任何執行每次迭代同步集體通信的類別,此問題仍然存在。

Join 是一個 context manager,用於圍繞每個 rank 的訓練迴圈,以促進具有不均勻輸入的訓練。context manager 允許那些提早耗盡輸入的 rank(即提早加入)遮蔽那些尚未加入的 rank 所執行的集體通信。通信被遮蔽的方式由 hooks 指定。

JoinDistributedDataParallel 一起使用

PyTorch 的 DistributedDataParallel 可以與 Join context manager 搭配使用。以下是一個使用範例

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

這會產生以下輸出(其中來自 rank 0 和 rank 1 的 print() 的順序可能任意)

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

注意

在引入此通用 Join context manager 之前,DistributedDataParallel 提供了自己的 join() context manager。在上面的範例中,使用 with Join([model]): 等同於使用 with model.join():。現有 DistributedDataParallel.join() 的一個限制是它不允許多個參與類別,例如 DistributedDataParallelZeroRedundancyOptimizer 一起使用。

JoinDistributedDataParallelZeroRedundancyOptimizer 一起使用

Join context manager 不僅可以與單個類別一起使用,還可以與多個類別一起使用。PyTorch 的 ZeroRedundancyOptimizer 也與 context manager 相容,因此在這裡,我們研究如何修改先前的範例以同時使用 DistributedDataParallelZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

這將產生與之前相同的輸出。值得注意的變化是將 ZeroRedundancyOptimizer 實例額外傳遞到 Join() 中。

傳遞關鍵字引數

類別可以提供關鍵字引數,在執行時修改它們在 context manager 中的行為。例如,DistributedDataParallel 提供了一個引數 divide_by_initial_world_size,它決定了梯度是否除以初始 world size,還是除以有效的 world size (即非加入 ranks 的數量)。這些關鍵字引數可以直接傳遞到 context manager 中。

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

警告

傳遞到 context manager 中的關鍵字引數會在所有參與的類別之間共享。這應該不是一個限制,因為我們不預期多個 Joinable 需要相同引數的不同設定的情況。儘管如此,這還是需要記住的一點。

Join 如何運作?

現在我們已經看到了一些使用 Join context manager 的初步範例,讓我們更深入地探討它是如何運作的。這將更深入地了解它提供的完整功能,並讓您準備好使您自己的自定義類別相容。在這裡,我們將介紹 Join 類別,以及支援類別 JoinableJoinHook

Joinable

首先,與 Join context manager 相容的類別必須繼承自抽象基底類別 Joinable。特別是,Joinable 必須實作

  • join_hook(self, **kwargs) -> JoinHook

這會傳回 JoinableJoinHook 實例,決定加入的 processes 應該如何 shadow 由 Joinable 執行的每次迭代的 collective communications。

  • join_device(self) -> torch.device

這會傳回一個 device,供 Join context manager 用於執行 collective communications,例如 torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

這會傳回 process group,供 Join context manager 用於執行 collective communications。

特別是,需要 join_devicejoin_process_group 屬性,以確保 context manager 可以排程 joined 和 non-joined processes 之間的 collective communications。一種用法是使用 all-reduce 來計算每次迭代中 non-joined processes 的數量。另一種用法是實作 throw_on_early_termination=True 所需的機制,我們將在稍後解釋。

DistributedDataParallelZeroRedundancyOptimizer 已經繼承自 Joinable 並實作了上述方法,這就是為什麼我們可以直接在先前的範例中使用它們的原因。

Joinable 類別應確保呼叫 Joinable 建構函式,因為它初始化了一個 JoinConfig 實例,context manager 在內部使用它來確保正確性。這將儲存在每個 Joinable 中,作為一個欄位 _join_config

JoinHook

接下來,讓我們分解 JoinHook 類別。JoinHook 提供了兩個進入 context manager 的入口點

  • main_hook(self) -> None

當存在尚未加入的 rank 時,每個 joined rank 會重複呼叫這個 hook。它旨在 shadow 由 Joinable 在每個訓練迭代中執行的 collective communications(例如,在一個正向傳遞、反向傳遞和優化器步驟中)。

  • post_hook(self, is_last_joiner: bool) -> None

當所有 ranks 都已加入時,會呼叫這個 hook 一次。它會傳遞一個額外的 bool 引數 is_last_joiner,它表示 rank 是否是最後加入的 ranks 之一。該引數對於同步可能很有用。

為了給出這些 hooks 可能看起來如何的具體範例,提供的 ZeroRedundancyOptimizer 主 hook 執行一個正常的優化器步驟,因為 joined rank 仍然負責更新和同步其參數 shard,並且提供的 DistributedDataParallel post-hook 從最後加入的 ranks 之一廣播最終更新的模型,以確保它在所有 ranks 之間都是相同的。

Join

最後,讓我們檢查一下這些如何放入 Join 類別本身。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我們在先前的範例中所看到的,建構函式接收一個參與訓練迴圈的 Joinable s 列表。這些應該是在每次迭代中執行 collective communications 的類別。

enable 是一個 bool 型別,如果您確定不會有不均勻的輸入,則可以將其設為 False。在這種情況下,上下文管理器會變得無效,類似於 contextlib.nullcontext()。這也可能會停用參與的 Joinable 中的 join 相關計算。

throw_on_early_termination 是一個 bool 型別,可以設定為 True,以便在偵測到不均勻的輸入時,每個 rank 都會引發例外。這對於不符合上下文管理器需求的情況很有用,這種情況通常發生在來自不同類別的 collective communications 可能會任意交錯時,例如在使用 DistributedDataParallel 時,模型中具有 SyncBatchNorm 層。在這種情況下,應將此參數設定為 True,以便應用程式邏輯可以捕獲例外,並決定如何繼續。

  • 核心邏輯發生在 __exit__() 方法中,該方法在存在未加入的 rank 時迴圈執行,調用每個 Joinable 的主要 hook,然後在所有 rank 都加入後,調用它們的 post hook。主要 hook 和 post-hook 均以 Joinable 傳入的順序進行迭代。

  • 上下文管理器需要來自未加入程序的心跳訊號。因此,每個 Joinable 類別都應該在其每次迭代的 collective communications 之前調用 Join.notify_join_context()。上下文管理器將確保只有第一個傳入的 Joinable 實際發送心跳訊號。

警告

如上文關於 throw_on_early_termination 所述,Join 上下文管理器與某些類別的組合不相容。JoinableJoinHook 必須是可序列化的,因為每個 hook 都在繼續下一個 hook 之前完全執行。換句話說,兩個 hook 不能重疊。此外,目前,主要 hook 和 post-hook 均以相同的確定性順序進行迭代。如果這似乎是一個主要的限制,我們可以修改 API 以允許自訂排序。

使 Toy Class 與 Join 一起工作

由於前一節介紹了幾個概念,讓我們透過一個 toy 範例在實踐中看看它們。在這裡,我們將實作一個類別,用於計算在其 rank 加入之前所有 rank 上看到的輸入數量。這應該為您提供一個基本的概念,說明如何使您自己的類別與 Join 上下文管理器相容。

具體來說,以下程式碼會讓每個 rank 列印出 (1) 在其加入之前所有 rank 上看到的輸入數量,以及 (2) 所有 rank 上的總輸入數量。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由於 rank 0 看到 5 個輸入,而 rank 1 看到 6 個輸入,因此會產生以下輸出

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

需要強調的幾個關鍵點

  • 一個 Counter 實例每次迭代執行一次 all-reduce,因此主要 hook 也執行一次 all-reduce 以 shadows 它。

  • Counter 類別在其 __call__() 方法的開頭調用 Join.notify_join_context(),因為那是其每次迭代的 collective communications(即其 all-reduce)之前的位置。

  • is_last_joiner 參數用於確定 post-hook 中的廣播源。

  • 我們將 sync_max_count 關鍵字參數傳遞給上下文管理器,然後將其轉發到 Counter 的 join hook。

文件

存取 PyTorch 的全面開發人員文件

查看文件

教學

取得初學者和進階開發人員的深入教學

查看教學

資源

尋找開發資源並取得您的問題解答

查看資源