
Asynchronous Saving with Distributed Checkpoint (DCP)

Created On: Jul 22, 2024 | Last Updated: Jul 22, 2024 | Last Verified: Nov 05, 2024

Authors: Lucas Pasqualin, Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang

Checkpointing is often a bottleneck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow. One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example from the Getting Started with Distributed Checkpoint tutorial to show how this can be integrated quite easily with torch.distributed.checkpoint.async_save.

What you will learn
  • How to use DCP to generate checkpoints in parallel

  • Effective strategies to optimize performance

Prerequisites

  • Familiarity with the Getting Started with Distributed Checkpoint (DCP) tutorial

Async Checkpointing Overview

Before getting started with asynchronous checkpointing, it's important to understand its differences and limitations as compared to synchronous checkpointing. Specifically:

  • Memory requirements - Asynchronous checkpointing works by first copying models into internal CPU buffers.

    This is useful since it ensures model and optimizer weights are not changing while the model is still checkpointing, but does raise CPU memory by a factor of checkpoint_size_per_rank X number_of_ranks. Additionally, users should take care to understand the memory constraints of their systems. Specifically, pinned memory implies the usage of page-locked memory, which can be scarce as compared to pageable memory.

  • Checkpoint management - Since checkpointing is asynchronous, it is up to the user to manage concurrently run checkpoints. In general, users can

    employ their own management strategies by handling the future object returned from async_save. For most users, we recommend limiting checkpoints to one asynchronous request at a time, avoiding additional memory pressure per request, as shown in the sketch right after this list.
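
To make the one-request-at-a-time recommendation concrete, here is a minimal sketch of that management strategy. It assumes a hypothetical train_step() helper and a hypothetical build_state_dict() helper standing in for your training loop and state-dict construction; the full FSDP example below follows the same pattern.

import torch.distributed.checkpoint as dcp

# Sketch only: keep at most one asynchronous checkpoint in flight.
# `train_step` and `build_state_dict` are hypothetical placeholders for
# your training loop and state-dict construction.
pending_checkpoint = None
for step in range(10):
    train_step()

    # Block on the previous request (if any) before staging a new one, so
    # only one CPU staging buffer is alive at a time.
    if pending_checkpoint is not None:
        pending_checkpoint.result()

    pending_checkpoint = dcp.async_save(
        build_state_dict(), checkpoint_id=f"checkpoint_step{step}"
    )

# Make sure the final checkpoint has finished before shutting down.
if pending_checkpoint is not None:
    pending_checkpoint.result()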

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        # waits for checkpointing to finish if one exists, avoiding queuing more than one checkpoint request at a time
        if checkpoint_future is not None:
            checkpoint_future.result()

        state_dict = { "app": AppState(model, optimizer) }
        checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    # wait for the final checkpoint request to finish before tearing down the process group
    if checkpoint_future is not None:
        checkpoint_future.result()

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running async checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
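
Loading a checkpoint produced by async_save works exactly like loading a synchronous DCP checkpoint. As a brief sketch, assuming the model and optimizer are reconstructed and FSDP-wrapped the same way as during training and that the step-9 checkpoint directory from the loop above exists:

# Sketch: restore the last checkpoint written by the loop above.
# `model` and `optimizer` must be constructed (and FSDP-wrapped) the same
# way as when the checkpoint was saved.
state_dict = {"app": AppState(model, optimizer)}
dcp.load(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step9")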

Even more performance with Pinned Memory

If the above optimization is still not performant enough, you can take advantage of an additional optimization for GPU models which utilizes a pinned memory buffer for checkpoint staging. Specifically, this optimization attacks the main overhead of asynchronous checkpointing, which is the in-memory copying to checkpointing buffers. By maintaining a pinned memory buffer between checkpoint requests, users can take advantage of direct memory access to speed up this copy.

Note

The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps, leading to the same peak memory pressure being sustained through the application lifetime.
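
To build intuition for why a page-locked staging buffer helps, the short, standalone illustration below (not part of the DCP API) copies a GPU tensor into a reusable pinned CPU buffer; copies into pinned memory can use direct memory access and run without blocking the host.

import torch

# Illustration only: pinned (page-locked) CPU memory enables fast,
# potentially asynchronous device-to-host copies, which is what DCP's
# pinned-memory staging takes advantage of.
gpu_tensor = torch.randn(1024, 1024, device="cuda")

# Allocate a reusable page-locked CPU buffer once...
pinned_buffer = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True)

# ...and reuse it for every staging copy; non_blocking=True lets the copy
# overlap with other host work.
pinned_buffer.copy_(gpu_tensor, non_blocking=True)
torch.cuda.synchronize()  # ensure the copy has completed before reading the buffer

The full example below achieves the same effect through DCP by reusing a single storage writer, which keeps a cached, pinned staging buffer across checkpoint requests.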

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # The storage writer defines our 'staging' strategy, where staging is considered the process of copying
    # checkpoints to in-memory buffers. By setting `cache_staged_state_dict=True`, we enable efficient memory copying
    # into a persistent buffer with pinned memory enabled.
    # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
    # pinned memory buffer.
    writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        state_dict = { "app": AppState(model, optimizer) }
        if checkpoint_future is not None:
            # waits for checkpointing to finish, avoiding queuing more than one checkpoint request at a time
            checkpoint_future.result()
        checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    # wait for the final checkpoint request to finish before tearing down the process group
    if checkpoint_future is not None:
        checkpoint_future.result()

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

Conclusion

In conclusion, we have learned how to use DCP's async_save() API to generate checkpoints off the critical training path. We've also learned about the additional memory and concurrency overhead introduced by using this API, as well as additional optimizations which utilize pinned memory to speed things up even further.
