捷徑

使用 ZeroRedundancyOptimizer 分片 Optimizer 狀態

建立於:2021 年 2 月 26 日 | 最後更新:2021 年 10 月 20 日 | 最後驗證:未驗證

在本食譜中,您將學習

需求

什麼是 ZeroRedundancyOptimizer

ZeroRedundancyOptimizer 的概念來自 DeepSpeed/ZeRO 專案Marian,它們在分散式資料平行處理程序中分片 optimizer 狀態,以減少每個處理程序的記憶體佔用量。在分散式資料平行入門教學中,我們展示了如何使用 DistributedDataParallel (DDP) 來訓練模型。在該教學中,每個處理程序都保留一個專用的 optimizer 副本。由於 DDP 已經在反向傳播中同步梯度,因此所有 optimizer 副本將在每次迭代中對相同的參數和梯度值進行操作,這就是 DDP 保持模型副本處於相同狀態的方式。通常,optimizer 也會維護本機狀態。例如,Adam optimizer 使用每個參數的 exp_avgexp_avg_sq 狀態。因此,Adam optimizer 的記憶體消耗至少是模型大小的兩倍。 鑑於此觀察,我們可以透過在 DDP 處理程序之間分片 optimizer 狀態來減少 optimizer 記憶體佔用量。更具體地說,不是為所有參數建立每個參數的狀態,而是不同 DDP 處理程序中的每個 optimizer 實例僅保留所有模型參數的分片的 optimizer 狀態。 optimizer step() 函數僅更新其分片中的參數,然後將其更新的參數廣播到所有其他同級 DDP 處理程序,以便所有模型副本仍然處於相同的狀態。

如何使用 ZeroRedundancyOptimizer

下面的程式碼示範了如何使用 ZeroRedundancyOptimizer。大部分程式碼與 分散式資料平行說明中提供的簡單 DDP 範例相似。主要區別在於 example 函數中的 if-else 子句,它包裝了 optimizer 建構,在 ZeroRedundancyOptimizerAdam optimizer 之間切換。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")



def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

輸出如下所示。 當使用 Adam 啟用 ZeroRedundancyOptimizer 時,optimizer step() 峰值記憶體消耗是 vanilla Adam 記憶體消耗的一半。 這與我們的預期一致,因為我們正在兩個處理程序之間分片 Adam optimizer 狀態。 輸出還顯示,使用 ZeroRedundancyOptimizer,模型參數在一次迭代後仍然以相同的值結束(參數總和在使用和不使用 ZeroRedundancyOptimizer 的情況下是相同的)。

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875

文件

存取 PyTorch 的全面開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源