捷徑

DDP 通訊掛鉤

DDP 通訊掛鉤是一個通用介面,用於控制如何在 worker 之間通訊梯度,方法是覆寫 DistributedDataParallel 中的 vanilla allreduce。提供了一些內建的通訊掛鉤,使用者可以輕鬆地應用任何這些掛鉤來最佳化通訊。此外,掛鉤介面還可以支援使用者定義的通訊策略,以用於更進階的使用案例。

如何使用通訊掛鉤?

若要使用通訊掛鉤,使用者只需讓 DDP 模型在訓練迴圈之前註冊該掛鉤,如下所示。

torch.nn.parallel.DistributedDataParallel.register_comm_hook()

通訊掛鉤在什麼上運作?

通訊掛鉤 (communication hook) 提供了一種彈性的方式來執行 allreduce 梯度。因此,它主要在 allreduce 之前作用於每個副本上的梯度,這些梯度會被分桶 (bucketized) 以增加通訊和計算之間的重疊。特別地,torch.distributed.GradBucket 代表一個要進行 allreduce 的梯度張量桶。

class torch.distributed.GradBucket

這個類別主要將扁平化的梯度張量 (由 buffer() 返回) 傳遞給 DDP 通訊掛鉤。此張量可以進一步分解為此 bucket 內每個參數的張量列表 (由 get_per_parameter_tensors() 返回) 以應用逐層 (layer-wise) 操作。

torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) int

警告

由於 bucket 在第一次迭代後會重建,因此不應依賴訓練開始時的索引。

回傳

儲存幾個連續層梯度的 bucket 索引。所有梯度都被分桶。

torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) torch.Tensor
回傳

一個扁平化的 1D torch.Tensor buffer,可以進一步分解為此 bucket 內每個參數的張量列表。

torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
回傳

一個 torch.Tensor 的列表。列表中的每個張量對應於一個梯度。

torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) bool
回傳

這個 bucket 是否是在一次迭代中最後一個要進行 allreduce 的 bucket。這也意味著這個 bucket 對應於前向傳遞 (forward pass) 中的前幾層。

torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) None

將 bucket 中的張量替換為輸入張量 buffer。

torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
回傳

一個 torch.Tensor 的列表。列表中的每個張量對應於一個模型參數。

預設通訊掛鉤

預設的通訊掛鉤是簡單的**無狀態 (stateless)** 掛鉤,因此 register_comm_hook 中的輸入狀態要麼是一個進程組 (process group),要麼是 None。輸入的 bucket 是一個 torch.distributed.GradBucket 物件。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[原始碼][原始碼]

使用 GradBucket 張量呼叫 allreduce

一旦梯度張量在所有 worker 上聚合,它的 then 回調 (callback) 就會取平均值並返回結果。

如果使用者註冊了這個 DDP 通訊 hook,預期的 DDP 結果會與未註冊 hook 的情況相同。因此,這不會改變 DDP 的行為,使用者可以將其用作參考,或者修改此 hook 以記錄有用的資訊或用於任何其他目的,而不會影響 DDP 的行為。

範例:
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
回傳類型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[原始碼][原始碼]

透過將 GradBucket 轉換為 torch.float16 並除以 process group 大小來壓縮。

這個 DDP 通訊 hook 實現了一種簡單的梯度壓縮方法,它將 GradBucket 張量轉換為半精度浮點格式 (torch.float16),然後將其除以 process group 的大小。它 allreduce 這些 float16 梯度張量。一旦壓縮後的梯度張量被 allreduce,鏈式回呼 decompress 會將其轉換回輸入資料類型 (例如 float32)。

範例:
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
回傳類型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[原始碼][原始碼]

警告:這個 API 是實驗性的,並且需要 NCCL 版本高於 2.9.6。

這個 DDP 通訊 hook 實現了一種簡單的梯度壓縮方法,它將 GradBucket 張量轉換為半精度 Brain 浮點格式 (torch.bfloat16),然後將其除以 process group 的大小。它 allreduce 這些 bfloat16 梯度張量。一旦壓縮後的梯度張量被 allreduce,鏈式回呼 decompress 會將其轉換回輸入資料類型 (例如 float32)。

範例:
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
回傳類型

Future[Tensor]

此外,還提供了一個通訊 hook 包裝器,以支援 fp16_compress_hook()bf16_compress_hook() 作為包裝器,它可以與其他通訊 hook 結合使用。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[原始碼][原始碼]

將輸入張量轉換為 torch.float16,將 hook 的結果轉換回輸入 dtype。

這個包裝器將給定 DDP 通訊 hook 的輸入梯度張量轉換為半精度浮點格式 (torch.float16),並將給定 hook 的結果張量轉換回輸入資料類型,例如 float32。因此,fp16_compress_hook 等同於 fp16_compress_wrapper(allreduce_hook)

範例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
回傳類型

Callable[[Any, GradBucket], Future[Tensor]]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[原始碼][原始碼]

警告:這個 API 是實驗性的,並且需要 NCCL 版本高於 2.9.6。

這個包裝器將給定 DDP 通訊 hook 的輸入梯度張量轉換為半精度 Brain 浮點格式 <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format> `_ (``torch.bfloat16`),並將給定 hook 的結果張量轉換回輸入資料類型,例如 float32

因此,bf16_compress_hook 等同於 bf16_compress_wrapper(allreduce_hook)

範例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
回傳類型

Callable[[Any, GradBucket], Future[Tensor]]

PowerSGD 通訊 Hook

PowerSGD (Vogels et al., NeurIPS 2019) 是一種梯度壓縮演算法,它可以提供非常高的壓縮率並加速頻寬受限的分散式訓練。 這個演算法需要維護一些超參數和內部狀態。 因此,PowerSGD 通訊 hook 是一個有狀態的 hook,使用者需要提供一個如下定義的狀態物件。

PowerSGD 狀態

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source][source]

儲存訓練期間所有梯度之演算法超參數與內部狀態。

特別地,matrix_approximation_rankstart_powerSGD_iter 是使用者應調整的主要超參數。為提升效能,我們建議保持二元超參數 use_error_feedbackwarm_start 啟用狀態。

  1. matrix_approximation_rank 控制壓縮後的低秩張量大小,其決定了壓縮率。秩越低,壓縮程度越高。

    1.1. 若 matrix_approximation_rank 太低,完整的模型品質將需要更多訓練步驟才能達到,或者永遠無法達到,並導致準確性降低。

    1.2. 增加 matrix_approximation_rank 會大幅增加壓縮的計算成本,且準確性可能不會在超過特定 matrix_approximation_rank 閾值後進一步提升。

為了調整 matrix_approximation_rank,我們建議從 1 開始,並以 2 的倍數增加 (例如指數網格搜尋,1, 2, 4, …),直到達到令人滿意的準確性。通常只會使用一個小值 1-4。對於某些 NLP 任務 (如原始論文的附錄 D 所示),此值已增加到 32。

  1. start_powerSGD_iter 會延遲 PowerSGD 壓縮,直到步驟 start_powerSGD_iter,且 vanilla allreduce 會在步驟 start_powerSGD_iter 之前運行。這種 vanilla allreduce + PowerSGD 的混合方案可以有效地提高準確性,即使使用相對較小的 matrix_approximation_rank。這是因為訓練階段的開始通常對不準確的梯度非常敏感,並且過早壓縮梯度可能會使訓練快速進入次優軌跡,這可能對準確性產生不可挽回的影響。

為了調整 start_powerSGD_iter,我們建議從總訓練步驟的 10% 開始,並增加它直到達到令人滿意的準確性。如果在訓練中有熱身階段,則 start_powerSGD_iter 通常應不小於熱身步驟的數量。

  1. min_compression_rate 是壓縮圖層時所需的最低壓縮率。由於壓縮產生的計算開銷,僅當可以充分節省頻寬時,張量才值得壓縮,其中 (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols。如果無法滿足指定的壓縮率閾值,則將直接 allreduce 張量而不進行壓縮。

一旦 PowerSGD 壓縮開始,壓縮統計資料就會每 compression_stats_logging_frequency 次迭代記錄一次。

  1. orthogonalization_epsilon 可以是一個非常小的值 (例如,1e-8),添加到正交化步驟中的每個正規化矩陣列,以防止任何列全為 0 時出現除以零的錯誤。如果已經可以避免這種情況 (例如,通過批次正規化),建議使用 epsilon 值 0 以提高準確性。

  2. batch_tensors_with_same_shape 控制是否以批次操作壓縮和解壓縮具有相同形狀的張量,以實現更高的並行性。請注意,您還應該增加 bucket 大小 (即 DDP 建構子中的 bucket_cap_mb 參數),以使更多相同形狀的張量出現在同一個 bucket 中,但是這可能會減少計算和通訊之間的重疊,並且由於堆疊相同形狀的張量而增加記憶體佔用。如果壓縮/解壓縮計算是瓶頸,則設定為 True

警告

如果啟用錯誤回饋或熱身,則 DDP 中允許的 start_powerSGD_iter 的最小值為 2。這是因為 DDP 中還有另一個內部最佳化,會在迭代 1 時重建 buckets,這可能會與在重建過程之前記憶的任何張量衝突。

PowerSGD Hooks

警告

PowerSGD 通常需要與模型梯度大小相同的額外記憶體,以啟用錯誤回饋,這可以補償有偏差的壓縮通訊並提高準確性。

警告

PowerSGD hooks 可能與 Apex 自動混合精度套件 衝突。請改用 PyTorch 原生自動混合精度套件

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[原始碼][原始碼]

實作 PowerSGD 演算法。

這個 DDP 通訊 Hook 實作了 PowerSGD 梯度壓縮演算法,如 論文 中所述。一旦梯度張量在所有 worker 之間聚合,此 Hook 會套用以下壓縮方式:

  1. 將展平的 1D 梯度張量視為每個參數張量的列表,並將所有張量分成兩組

    1.1 應該在 allreduce 之前壓縮的張量,因為壓縮可以在頻寬上節省足夠的量。

    1.2 其餘的張量將直接進行 allreduce 而不進行壓縮,包括所有向量張量(用於偏差)。

  2. 處理未壓縮的張量

    2.1. 為這些未壓縮的張量分配連續的記憶體,並以批次方式 allreduce 所有未壓縮的張量,而無需壓縮;

    2.2. 將個別的未壓縮張量從連續的記憶體複製回輸入張量。

  3. 處理應使用 PowerSGD 壓縮進行壓縮的張量

    3.1. 對於每個張量 M,建立兩個低秩張量 P 和 Q 來分解 M,使得 M = PQ^T,其中 Q 從標準常態分佈初始化並進行正交化;

    3.2. 計算 Ps 中的每個 P,它等於 MQ;

    3.3. 以批次方式 Allreduce Ps;

    3.4. 正交化 Ps 中的每個 P;

    3.5. 計算 Qs 中的每個 Q,它近似等於 M^TP;

    3.6. 以批次方式 Allreduce Qs;

    3.7. 計算所有壓縮張量中的每個 M,它近似等於 PQ^T。

請注意,此通訊 Hook 對於前 state.start_powerSGD_iter 次迭代強制執行 vanilla allreduce。這不僅讓使用者更能控制加速和準確性之間的權衡,還有助於為未來的通訊 Hook 開發人員抽象化 DDP 內部優化的一些複雜性。

參數
  • state (PowerSGDState) – 用於配置壓縮率並支援錯誤回饋、warm start 等的狀態資訊。若要調整壓縮配置,主要需要調整 matrix_approximation_rankstart_powerSGD_itermin_compression_rate

  • bucket (dist.GradBucket) – 儲存 1D 展平梯度張量的 Bucket,該張量以批次處理多個每個變數的張量。請注意,由於 DDP 通訊 Hook 僅支援單進程單裝置模式,因此此 Bucket 中僅儲存一個張量。

回傳

通訊的未來處理程式,它會就地更新梯度。

回傳類型

Future[Tensor]

範例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                          start_powerSGD_iter=10, min_compression_rate=0.5)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[原始碼][原始碼]

實作簡化的 PowerSGD 演算法。

這個 DDP 通訊 Hook 實作了簡化的 PowerSGD 梯度壓縮演算法,如 論文 中所述。此變體不會逐層壓縮梯度,而是壓縮批次處理所有梯度的展平輸入張量。因此,它比 powerSGD_hook() 更快,但通常會導致 準確性低很多,除非 matrix_approximation_rank 為 1。

警告

增加此處的 matrix_approximation_rank 不一定會提高準確性,因為在沒有列/行對齊的情況下批次處理每個參數的張量可能會破壞低秩結構。因此,使用者應始終首先考慮 powerSGD_hook(),並且只有在 matrix_approximation_rank 為 1 時可以實現令人滿意的準確性時,才考慮此變體。

一旦梯度張量在所有 worker 之間聚合,此 Hook 會套用以下壓縮方式:

  1. 將展平的 1D 梯度張量視為具有 0 填充的正方形張量 M;

  2. 建立兩個低秩張量 P 和 Q 來分解 M,使得 M = PQ^T,其中 Q 從標準常態分佈初始化並進行正交化;

  3. 計算 P,它等於 MQ;

  4. Allreduce P;

  5. 正交化 P;

  6. 計算 Q,它近似等於 M^TP;

  7. Allreduce Q;

  8. 計算 M,它近似等於 PQ^T。

  9. 將輸入張量截斷為原始長度。

請注意,此通訊 Hook 對於前 state.start_powerSGD_iter 次迭代強制執行 vanilla allreduce。這不僅讓使用者更能控制加速和準確性之間的權衡,還有助於為未來的通訊 Hook 開發人員抽象化 DDP 內部優化的一些複雜性。

參數
  • state (PowerSGDState) – 用於配置壓縮率並支援錯誤回饋、warm start 等的狀態資訊。若要調整壓縮配置,主要需要調整 matrix_approximation_rankstart_powerSGD_iter

  • bucket (dist.GradBucket) – 儲存 1D 展平梯度張量的 Bucket,該張量以批次處理多個每個變數的張量。請注意,由於 DDP 通訊 Hook 僅支援單進程單裝置模式,因此此 Bucket 中僅儲存一個張量。

回傳

通訊的未來處理程式,它會就地更新梯度。

回傳類型

Future[Tensor]

範例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

偵錯通訊 Hook

顧名思義,偵錯通訊 Hook 用於偵錯和效能最佳化目的。

警告

偵錯通訊 Hook 不一定會輸出正確的結果。

torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[原始碼][原始碼]

傳回一個包裝輸入的 future,因此它是一個不產生任何通訊額外負荷的無操作。

這個 Hook 應用於 Allreduce 最佳化的 headroom 分析,而不是正常的梯度同步。例如,如果在此 Hook 註冊後,僅能觀察到少於 10% 的訓練時間加速,這通常表示 Allreduce 不是這種情況下的效能瓶頸。如果無法輕易擷取 GPU 追蹤,或追蹤分析受到某些因素(例如 Allreduce 和計算之間的重疊,或跨 rank 的非同步)的複雜化,則這種檢測特別有用。

範例:
>>> ddp_model.register_comm_hook(None, noop_hook)
回傳類型

Future[Tensor]

通訊 Hook 的檢查點 (Checkpointing)

具狀態的通訊 hook 可以儲存為模型檢查點的一部分,以啟用訓練器重新啟動。為了使 hook 可序列化,應該定義 __setstate____getstate__

警告

__getstate__ 應該從返回的字典中排除不可序列化的屬性。

警告

__setstate__ 應該正確初始化從提供的 state 中排除的不可序列化的屬性。

PowerSGDState 已經實作了 __setstate____getstate__,可以作為參考。

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[source][source]
__getstate__()[source][source]

返回一個 Dict[str, Any],它將被 pickle 並儲存。

process_group 不可序列化,並且從返回的狀態中排除。

__setstate__(state)[source][source]

獲取提供的 state 並將其設定為此 PowerSGDState 實例。

process_group 被設定為預設值。

這是一個儲存和重新載入 PowerSGD 狀態和 hook 的簡單端到端範例。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(24,24)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(24,12)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run_demo(demo_fn, world_size):
    mp.spawn(
        demo_fn,
        args=(world_size,),
        nprocs=world_size,
        join=True)

def demo_serialization(rank, world_size):
    setup(rank, world_size)

    CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

    model = SimpleModel().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    powersgd_hook = powerSGD.powerSGD_hook
    powersgd_state = powerSGD.PowerSGDState(process_group=None)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    state = {
        'state_dict': ddp_model.state_dict(),
        'comm_hook': powersgd_hook,
        'comm_hook_state': powersgd_state}

    if rank == 0:
        torch.save(state, CHECKPOINT)

    dist.barrier()
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint = torch.load(CHECKPOINT, map_location=map_location)

    new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
    new_ddp_model.load_state_dict(checkpoint['state_dict'])
    powersgd_hook = checkpoint['comm_hook']
    powersgd_state = checkpoint['comm_hook_state']

    new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    if rank == 0:
        os.remove(CHECKPOINT)

    cleanup()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_serialization, world_size)

致謝

非常感謝 PowerSGD 論文作者 Thijs Vogels 對 PowerSGD 通訊 hook 的程式碼審查,以及 比較實驗,這些實驗表明 PowerSGD 通訊 hook 的效能與原始 論文中的實作相當。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和高級開發人員的深入教學

檢視教學

資源

尋找開發資源並取得您的問題解答

檢視資源