• 教學 >
  • 使用完全分片資料平行 (FSDP) 的進階模型訓練
捷徑

使用完全分片資料平行 (FSDP) 的進階模型訓練

建立於:2024 年 10 月 31 日 | 最後更新:2024 年 10 月 31 日 | 最後驗證:2024 年 11 月 05 日

作者Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao

您將學到什麼
  • PyTorch 的完全分片資料平行模組:用於在資料平行工作者之間分片模組參數的封裝器。

資料平行工作者。

先決條件
  • PyTorch 1.12 或更高版本

  • 閱讀關於 FSDP API 的資訊。

本教學介紹了 PyTorch 1.12 版本中完全分片資料平行 (FSDP) 的更多進階功能。 若要熟悉 FSDP,請參閱 FSDP 入門教學

在本教學中,我們將使用 FSDP 微調 HuggingFace (HF) T5 模型,以文字摘要為工作範例。

此範例使用 Wikihow,為了簡化起見,我們將展示在具有 8 個 A100 GPU 的單一節點 P4dn 執行個體上進行訓練。 我們現在有幾篇部落格文章((link1), (link2))和一篇關於多節點叢集上大規模 FSDP 訓練的 論文

FSDP 是一個生產就緒的套件,專注於易用性、效能和長期支援。 FSDP 的主要優點之一是減少每個 GPU 上的記憶體佔用量。 這使得可以使用比 DDP 更低的總記憶體來訓練更大的模型,並利用計算和通訊的重疊來有效地訓練模型。 這種降低的記憶體壓力可用於訓練更大的模型或增加批次大小,從而可能有助於提高整體訓練輸送量。 您可以在 此處 閱讀有關 PyTorch FSDP 的更多資訊。

本教學中的 FSDP 功能

  • Transformer 自動封裝策略

  • 混合精度

  • 在裝置上初始化 FSDP 模型

  • 分片策略

  • 反向預取

  • 透過串流到 CPU 儲存模型檢查點

FSDP 如何運作的回顧

在高階層次,FDSP 的運作方式如下

在建構函式中

  • 分片模型參數,且每個排名僅保留其自己的分片

在正向傳遞中

  • 執行 all_gather 以從所有排名收集所有分片,以還原此 FSDP 單元的完整參數,並執行正向計算

  • 捨棄它剛收集的非自有參數分片以釋放記憶體

在反向傳遞中

  • 執行 all_gather 以從所有排名收集所有分片,以還原此 FSDP 單元中的完整參數,並執行反向計算

  • 捨棄非自有參數以釋放記憶體。

  • 執行 reduce_scatter 以同步梯度

微調 HF T5

HF T5 預先訓練模型有四種不同的大小,從具有 6 千萬個參數的小型到具有 110 億個參數的 XXL。 在本教學中,我們示範使用 FSDP 微調 T5 3B 以進行文字摘要,使用 WikiHow 資料集。 本教學的主要重點是強調 FSDP 中可用的不同功能,這些功能有助於訓練超過 3B 個參數的大規模模型。 此外,我們涵蓋了基於 Transformer 模型的特定功能。 本教學的程式碼可在 Pytorch 範例 中找到。

設定

1.1 安裝最新的 PyTorch

pip3 install torch torchvision torchaudio

1.2 資料集設定

請建立一個 data 資料夾,從 wikihowAll.csvwikihowSep.cs 下載 WikiHow 資料集,並將它們放置在 data 資料夾中。 我們將使用來自 summarization_dataset 的 wikihow 資料集。

接下來,我們將以下程式碼片段新增到 Python 腳本 "T5_training.py"。

注意

本教學的完整原始程式碼可在 PyTorch 範例 中找到。

1.3 匯入必要的套件

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
 checkpoint_wrapper,
 CheckpointImpl,
 apply_activation_checkpointing_wrapper)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime

1.4 分散式訓練設定。 在此,我們使用兩個輔助函數來初始化分散式訓練的程序,然後在訓練完成後清理。 在本教學中,我們將使用 torch elastic,使用 torchrun,它會自動設定工作者 RANKWORLD_SIZE

def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

2.1 設定 HuggingFace T5 模型

def setup_model(model_name):
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer =  T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer

我們也在此新增了幾個輔助函數來取得日期和格式化記憶體度量。

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM'
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

2.2 定義一個訓練函數

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        sampler.set_epoch(epoch)
    if rank==0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)
        if rank==0:
            inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_accuracy = fsdp_loss[0] / fsdp_loss[1]


    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
            )
    return train_accuracy

2.3 定義一個驗證函數

def validation(model, rank, world_size, val_loader):
    model.eval()
    correct = 0
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(3).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank==0:
                inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

2.4 定義一個分散式訓練函數,該函數將模型包裝在 FSDP 中

def fsdp_main(args):

    model, tokenizer = setup_model("t5-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])


    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)


    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()


    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
    torch.cuda.set_device(local_rank)


    #init_start_event = torch.cuda.Event(enable_timing=True)
    #init_end_event = torch.cuda.Event(enable_timing=True)

    #init_start_event.record()

    bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
    )

    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32

    # model is on CPU before input to FSDP
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        #sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")


            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()

2.5 剖析引數並設定主要函數

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)

若要使用 torchrun 執行訓練

torchrun --nnodes 1 --nproc_per_node 4  T5_training.py

Transformer 封裝策略

先前的教學 中所述,auto_wrap_policy 是 FSDP 的功能之一,可讓您輕鬆地自動分片給定的模型,並將模型、最佳化器和梯度分片放入不同的 FSDP 單元中。

對於某些架構,例如 Transformer 編碼器-解碼器,模型的某些部分(如嵌入表)會與編碼器和解碼器共享。在這種情況下,我們需要將嵌入表放置在外部 FSDP 單元中,以便可以從編碼器和解碼器存取。此外,通過註冊 Transformer 的層類別,可以使分片計畫更具通信效率。在 PyTorch 1.12 中,FSDP 添加了此支援,現在我們為 Transformers 提供了一個包裝策略。

它可以按如下方式創建,其中 T5Block 代表 T5 Transformer 層類別(持有 MHSA 和 FFN)。

t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
torch.cuda.set_device(local_rank)


model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy)

要查看包裝後的模型,您可以輕鬆地印出模型,並以可視方式檢查分片和 FSDP 單元。

混合精度

FSDP 支援靈活的混合精度訓練,允許使用任意的降低精度類型(例如 fp16 或 bfloat16)。目前,BFloat16 僅在 Ampere GPU 上可用,因此您需要確認原生支援後才能使用它。例如,在 V100 上,BFloat16 仍然可以運行,但由於它以非原生方式運行,因此可能會導致顯著的效能下降。

要檢查是否原生支援 BFloat16,您可以使用以下方法:

bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

FSDP 中混合精度的優點之一是,它可以對參數、梯度和緩衝區提供不同精度等級的細緻控制,如下所示:

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    # Gradient communication precision.
    reduce_dtype=torch.float32,
    # Buffer precision.
    buffer_dtype=torch.float32,
)

請注意,如果未指定某種類型(參數、reduce、緩衝區),則根本不會進行轉換。

這種靈活性允許使用者進行細緻的控制,例如僅將梯度通信設置為以降低的精度進行,而所有參數/緩衝區的計算都以完整精度完成。這在節點內通信是主要瓶頸,並且參數/緩衝區必須以完整精度以避免準確性問題的情況下可能很有用。可以使用以下策略來實現:

grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)

在 2.4 中,我們只需將相關的混合精度策略添加到 FSDP 包裝器中。

model = FSDP(model,
       auto_wrap_policy=t5_auto_wrap_policy,
       mixed_precision=bfSixteen)

在我們的實驗中,我們觀察到使用 BFloat16 進行訓練可加速高達 4 倍,並且在某些實驗中記憶體減少約 30%,可用於增加批次大小。

在設備上初始化 FSDP 模型

在 1.12 中,FSDP 支援一個 device_id 參數,旨在在 device_id 給定的設備上初始化輸入 CPU 模組。當整個模型無法放入單個 GPU 上,但可以放入主機的 CPU 記憶體中時,這很有用。當指定 device_id 時,FSDP 會將模型移動到指定的設備上,以每個 FSDP 單元為基礎,從而避免 GPU OOM 問題,同時初始化速度比基於 CPU 的初始化快幾倍。

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())

分片策略

FSDP 分片策略預設設置為完全分片模型參數,梯度和優化器狀態會在所有 ranks 之間進行分片(也稱為 Zero3 分片)。如果您有興趣採用 Zero2 分片策略,其中僅分片優化器狀態和梯度,則 FSDP 支援通過使用 "ShardingStrategy.SHARD_GRAD_OP" 而不是 "ShardingStrategy.FULL_SHARD" 將分片策略傳遞給 FSDP 初始化來實現此功能,如下所示:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)

這將減少 FSDP 中的通信開銷,在這種情況下,它在正向傳遞之後和通過反向傳遞保持完整的參數。

這在反向傳遞期間節省了 all_gather,因此通信較少,但代價是更高的記憶體佔用。請注意,完整模型參數會在反向傳遞結束時釋放,並且 all_gather 將在下一個正向傳遞中發生。

反向預取

反向預取設定控制何時應請求下一個 FSDP 單元的參數。通過將其設置為 BACKWARD_PRE,可以開始請求下一個 FSDP 單元的參數,並在當前單元的計算開始之前更快到達。這重疊了 all_gather 通信和梯度計算,可以提高訓練速度,但會略微增加記憶體消耗。它可以在 2.4 的 FSDP 包裝器中使用,如下所示:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        backward_prefetch = BackwardPrefetch.BACKWARD_PRE)

backward_prefetch 具有兩種模式,BACKWARD_PREBACKWARD_POSTBACKWARD_POST 意味著在當前 FSDP 單元處理完成之前,不會請求下一個 FSDP 單元的參數,從而最大程度地減少記憶體開銷。在某些情況下,使用 BACKWARD_PRE 可以將模型訓練速度提高多達 2-10%,並且對於較大的模型,可以注意到更高的速度提升。

模型檢查點儲存,通過流式傳輸到 Rank0 CPU

要使用 FULL_STATE_DICT 儲存來儲存模型檢查點,該儲存以與本地模型相同的方式儲存模型,PyTorch 1.12 提供了一些實用程式來支援儲存較大的模型。

首先,可以指定 FullStateDictConfig,允許僅在 rank 0 上填充 state_dict 並卸載到 CPU。

使用此配置時,FSDP 將 allgather 模型參數,並將它們一個接一個地卸載到 CPU,僅在 rank 0 上。當最終儲存 state_dict 時,它將僅在 rank 0 上填充,並包含 CPU tensors。這避免了模型大於單個 GPU 記憶體時可能發生的 OOM,並允許使用者檢查點大小約為使用者機器上可用 CPU RAM 的模型。

此功能可以如下運行:

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, save_policy
        ):
            cpu_state = model.state_dict()
if rank == 0:
 save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
 torch.save(cpu_state, save_name)

總結

在本教程中,我們介紹了 Pytorch 1.12 中 FSDP 的許多新功能,並使用了 HF T5 作為運行範例。使用正確的包裝策略(尤其是對於 Transformer 模型),以及混合精度和反向預取,應加快您的訓練運行速度。此外,諸如在設備上初始化模型以及通過流式傳輸到 CPU 進行檢查點儲存等功能應有助於避免處理大型模型時出現 OOM 錯誤。

我們正在積極努力為下一個版本向 FSDP 添加新功能。如果您有回饋、功能要求、問題或在使用 FSDP 時遇到問題,請隨時通過在 PyTorch Github 儲存庫中提出問題來與我們聯繫。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源