• 教學 >
  • 分散式檢查點 (DCP) 入門
捷徑

分散式檢查點 (DCP) 入門

建立於:2023 年 10 月 02 日 | 最後更新:2024 年 10 月 30 日 | 最後驗證:2024 年 11 月 05 日

作者Iris ZhangRodrigo KumperaChien-Chin HuangLucas Pasqualin

注意

editgithub 中檢視和編輯此教學。

先決條件

在分散式訓練期間檢查 AI 模型可能具有挑戰性,因為參數和梯度在訓練器之間進行分割,並且當您恢復訓練時,可用的訓練器數量可能會發生變化。 Pytorch 分散式檢查點 (DCP) 可以幫助簡化此過程。

在本教學中,我們展示如何將 DCP API 與簡單的 FSDP 包裝模型一起使用。

DCP 如何運作

torch.distributed.checkpoint() 允許從多個排名並行儲存和載入模型。您可以使用此模組並行儲存到任意數量的排名,然後在載入時重新分片到不同的叢集拓撲。

此外,透過使用 torch.distributed.checkpoint.state_dict() 中的模組,DCP 提供了在分散式設定中優雅處理 state_dict 產生和載入的支援。這包括管理跨模型和最佳化器的完全限定名稱 (FQN) 映射,以及設定 PyTorch 提供的平行處理的預設參數。

DCP 在一些重要方面與 torch.save()torch.load() 不同

  • 它為每個檢查點產生多個檔案,每個排名至少一個。

  • 它以原地方式運作,這表示模型應首先分配其資料,而 DCP 會改為使用該儲存空間。

  • DCP 提供對 Stateful 物件的特殊處理(正式定義在 torch.distributed.checkpoint.stateful 中),如果定義了 state_dictload_state_dict 方法,則會自動呼叫它們。

注意

本教學中的程式碼在 8-GPU 伺服器上執行,但可以輕鬆地推廣到其他環境。

如何使用 DCP

在這裡,我們使用以 FSDP 包裝的玩具模型來進行示範。同樣地,API 和邏輯可以應用於更大的模型以進行檢查點。

儲存

現在,讓我們建立一個玩具模組,用 FSDP 包裝它,用一些虛擬輸入資料餵給它,然後儲存它。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = { "app": AppState(model, optimizer) }
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

請繼續檢查 checkpoint 目錄。您應該會看到 8 個檢查點檔案,如下所示。

Distributed Checkpoint

載入

儲存後,讓我們建立相同的 FSDP 包裝模型,並將儲存的狀態字典從儲存空間載入到模型中。您可以載入相同的 world size 或不同的 world size。

請注意,在載入之前,您必須先呼叫 model.state_dict(),並將其傳遞給 DCP 的 load_state_dict() API。 這與 torch.load() 截然不同,因為 torch.load() 僅需要檢查點的路徑即可載入。 我們需要在載入之前取得 state_dict 的原因是:

  • DCP 使用來自模型 state_dict 的預先分配儲存空間,從檢查點目錄載入。 在載入期間,傳入的 state_dict 會被就地更新。

  • DCP 需要從模型取得分片資訊,才能支援重新分片。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = { "app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

如果您想將儲存的檢查點載入到非 FSDP 包裝的模型中,例如用於推論的非分散式設定,您也可以使用 DCP 做到這一點。 預設情況下,DCP 以單程式多資料 (SPMD) 樣式儲存和載入分散式 state_dict。 但是,如果沒有初始化任何程序群組,DCP 會推斷其意圖是以「非分散式」樣式儲存或載入,這表示完全在目前的程序中進行。

注意

對多程式多資料的分散式檢查點支援仍在開發中。

import os

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn


CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no progress group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])

if __name__ == "__main__":
    print(f"Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()

格式

尚未提及的一個缺點是,DCP 以與使用 torch.save 產生的格式本質上不同的格式儲存檢查點。 當使用者希望與習慣 torch.save 格式的使用者分享模型時,或者只是想為其應用程式增加格式靈活性時,這可能會成為問題。 在這種情況下,我們在 torch.distributed.checkpoint.format_utils 中提供了 format_utils 模組。

為方便使用者,我們提供了一個命令列公用程式,其格式如下:

python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>

在上面的命令中,modetorch_to_dcpdcp_to_torch 其中之一。

或者,也為可能希望直接轉換檢查點的使用者提供了方法。

import os

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# converts the torch.save model back to DCP
dcp_to_torch_save(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")

結論

總而言之,我們已經了解如何使用 DCP 的 save()load() API,以及它們與 torch.save()torch.load() 的不同之處。 此外,我們還學習了如何使用 get_state_dict()set_state_dict() 在狀態字典生成和載入期間自動管理特定於平行處理的 FQN 和預設值。

如需更多資訊,請參閱以下內容:


為本教學課程評分

© Copyright 2024, PyTorch。

使用 Sphinx 建構,主題由 theme 提供,由 Read the Docs 提供。

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源