捷徑

FSDP 筆記

FSDP 預取細節

為了使 forward all-gather 與 forward 計算重疊,有兩種可能的機制

  1. 隱式正向預取(始終啟用)

  2. 顯式前向預取(forward_prefetch=True

隱式 forward 預取是指依賴於從一個獨立的 CUDA stream 發出 all-gather 操作,以便允許 all-gather 與在其之前發出的 forward 計算(從 CPU 的角度來看)重疊。 例如,如果我們有 layer 0 的 all-gather -> layer 0 的 forward 計算 -> layer 1 的 all-gather -> …,那麼 layer 1 的 all-gather 可以與 layer 0 的 forward 計算重疊,即使 CPU 執行緒是在其之後發出的。(第 1 個 all-gather 將無法與任何東西重疊。)

顯式 forward 預取是指更改 CPU 執行緒的發佈順序:例如 layer 0 的 all-gather -> layer 1 的 all-gather -> layer 0 的 forward 計算 -> …。 在 eager 模式下,通常無法知道哪個 layer 是下一個 layer(例如,範例中的 layer 1),因為仍在 layer 0 上執行。 因此,顯式 forward 預取應該只用於執行順序在迭代之間是固定的模型(我們有時稱之為「靜態圖」)。 FLAVA 就是不滿足此約束的一個範例。)

顯式 forward 預取只節省了發佈 layer 的 forward 計算核心的時間,但代價是必須在當前的 all-gather 輸出張量仍在使用的情況下分配下一個 all-gather 的輸出張量。 通過在當前的 forward 計算核心之前發佈下一個 all-gather,下一個 all-gather 可以在 GPU 上更早地開始。 對於大多數 LLM 工作負載而言,情況並非如此,因此沒有理由啟用 forward_prefetch=True

相反,對於 backward,我們必須使用顯式 backward 預取,否則 communication 和 computation 之間將沒有任何重疊。原因是我們對 all-gather 和 reduce-scatter 使用單個 NCCL process group (部分原因是早期 NCCL 版本中,在同一設備上的相同 ranks 上同時使用多個 NCCL process group 是不安全的)。 單個 NCCL process group 意味著單個內部 NCCL stream,reduce-scatter 和 all-gather 在其上序列執行。 因此,除非我們顯式地將 CPU 的發佈順序重新排序為下一個 all-gather -> 當前的 reduce-scatter,否則當前的 reduce-scatter 會阻塞下一個 all-gather,從而阻塞下一個 backward 計算,並阻止當前的 reduce-scatter 重疊。

Communication payload 大小

在 FSDP 中,communication 包括:

  1. forward 中參數的 all-gather

  2. backward 中參數的 all-gather

  3. backward 中梯度 (gradients) 的 reduce-scatter

如果使用 activation checkpointing(checkpoint()),則沒有額外的 communication,因為參數已經在 backward 期間被預取。

在 FSDP 設計中,每個 rank 的 communication payload 是這樣確定的:每次調用 FullyShardedDataParallel 都會創建一個 communication group,該 group 由 module.parameters() 中的參數組成,但不包括已分配給巢狀 FullyShardedDataParallel 實例的任何參數。 例如,對於 Llama,如果您將 FullyShardedDataParallel 應用於每個 transformer block 以及 root module,那麼每個 transformer block 都會有一個 communication group,最後還有一個包含初始 embedding 和最終 linear 的 communication group。 每個 communication group 對應於單個 all-gather 調用和單個 reduce-scatter 調用。 這樣,您應用 FullyShardedDataParallel 的方式決定了 communication 的大小。 一般來說,將 FSDP 應用於每個 transformer block 是 LLM 的一個很好的 heuristic 方法,並且在當前設計下很難做得更好。

讓我們考慮一個例子,我們有一個基於 Transformer 的模型,該模型在 8 個 GPU 上進行 sharding,其中 sharding 僅在 transformer block 級別發生,並且每個 transformer block 包含 1.6B 個參數,並且這些參數採用 fp32 (每個 4 bytes)。 這意味著一旦 sharded 後,每個 transformer block 將在每個 rank 上包含 0.2B 個參數。

  • forward pass 將以 0.2*4 = 0.8GB 的 chunks 在 all-gather 中進行 communication

  • backward pass 將 communication 2 次,每次 0.8GB (1 個 all-gather 和 1 個 reduce-scatter)

換句話說,將有 3 個 communication,每個 payload 為 0.8GB。 如果模型由 10 個 transformer block 組成,那麼總共將有 30 個 communication,總計 30*0.8=24GB

將每個 rank 的每次 communication 的 payload 大小形式化為 total_transformer_block_params_in_B*dtype_bytes/num_gpus (GB)。

請注意,在本例中,我們沒有包括 embedding 所需的額外 communication,這也應該考慮在內。 並且計算將取決於輸入和輸出 embeddings 是否 tied。 如果它們沒有 tied,則將有 2 倍以上的 communication。

FSDP 緩衝區大小

首先,讓我們介紹為 communication 分配的緩衝區

forward 目前需要 2 倍的 all-gather 緩衝區大小。 這是為什麼

正如 FSDP 預取細微差別 中解釋的那樣,在顯式 forward 預取(forward_prefetch=True`) 的情況下,對於 layer 0 all-gather -> layer 0 forward 計算 -> layer 1 all-gather 存在對 2 all-gather 大小的緩衝區的需求,因為一個緩衝區用於當前的 ``forward,而另一個用於進行預取。

理論上,在隱式的 forward 預取 (forward_prefetch=False,預設值) 的情況下,相同的序列應該只需要 1 個 buffer,但實際上仍然需要 2 倍 all-gather 大小的 buffers。原因是,在 flat-parameter FSDP 設計中,我們不會從 all-gather buffer 中複製出來。用於計算的參數直接視窗化到 all-gather buffer 中(事實上,“flat parameter”的主要優點正是這個原因)。在這種情況下,雖然 'layer 1 all-gather' 與 'layer 0 forward compute' 重疊,但 'layer 0 forward compute' 正在使用視窗化到 'layer 0 all-gather' buffer 中的參數。

一個自然的問題是,什麼時候你會想要 forward_prefetch=False 呢?對於靜態圖模型(如大多數 LLM),有一個主要的技術原因。更確切地說,實際上,我們很快為一些 CPU 限制的內部模型添加了這個選項,並且沒有在單元測試中測試每個程式碼路徑,因此我們對它的信心較低。forward_prefetching=False 可能更容易理解,因為我們不必檢查記錄的 forward 順序作為可能的“故障模式”;一個模組的 all-gather 總是可以根據其 profiler 追蹤中的 record_function 標籤找到。

backward 目前至少需要 2 倍 all-gather buffer 大小,並且可能需要更多一點。以下是原因

目前的 FSDP 設計使用 recordStream 來管理在一個 stream 中產生並在另一個 stream 中消耗的分配,這可能導致比預期更多的記憶體使用量。具體多多少取決於 GPU 核心的時序相對於 CPU 的時序,因此可能是“非確定性的”。limit_all_gathers=True 參數是一種緩解措施 - 有關更多詳細資訊,請參閱此討論 FSDP & CUDACachingAllocator

現有 FSDP 與 autograd 的協作方式

  • 現有 FSDP all-gather 了 flat_param,它是 autograd 的葉節點。

  • 它呼叫 torch.split 以取得 flat_param 的 1D 視窗,該視窗對應於其組成的原始參數。

  • 它在每個 1D 分割上呼叫 torch.view 以視窗化回 ND。

  • 這意味著在 backward 中,我們最終會得到 ViewBackward (ND -> 1D) 和 SplitWithSizesBackward(它是一個 concat)。特別是,每個單獨的梯度都作為一個單獨的分配來計算,並且會發生顯式的 concat 來構造 reduce-scatter 輸入 buffer。這實際上意味著在那個峰值記憶體點,reduce-scatter 的 buffer 大小是 2 倍。

總之,對於 backward,reduce-scatter 的 buffer 大小約為 2 倍,加上任何 recordStream 效應。

其次,讓我們討論額外的 buffers

一旦從所有 ranks 收集了分片參數,它們就需要一個額外的 buffer,大小為 *total_transformer_block_params_in_B*dtype_bytes* 用於完整參數 - 因此延續之前的範例,如果每個 transformer block 是 1.6B 參數,並且參數是 fp32,那麼它將是 *1.6*4=6.4GB* buffer。

並且需要 2 個這樣的 buffers,因為目前有一個正在使用,另一個正在預取。

總結一下,我們有

  1. total_transformer_block_params_in_B*dtype_bytes/num_gpus 的 2 倍通訊 buffers

  2. ``total_transformer_block_params_in_B*dtype_bytes 的 2 倍未分片 transformer block 參數 buffer

或者如果您一直在關注這個範例

  1. 2*1.6*4/8=1.6GB

  2. 2**1.6*4=12.8GB

總共 14.4GB

現在讓我們簡要討論一下 embeddings 發生了什麼,因為我們已經將它們排除在計算之外

鑑於我們討論的規則,您在以“通訊 buffer 大小的確定方式如下”開頭的註釋中包含了該規則,我們可以按如下方式分析

  • 假設我們將 FSDP 應用於根模組(例如 Transformer 類別)。假設我們進一步將 FSDP 應用於每個 transformer block(例如 TransformerBlock 類別)。

  • 最常見的是,embedding 和最終的線性投影是根 Transformer 類別的直接子項。

  • 按照我們的規則,這意味著 embedding 和最終的線性投影被分配給根 Transformer 的 flat 參數。

  • 我們有 *另一個* 特殊規則,即根不會在 forward 之後釋放其參數,因為它們無論如何都會立即在 backward 中進行 all-gather。

  • 總而言之,這意味著根的 flat 參數(包括 embedding 和最終投影)被 all-gather 以開始 forward,並保存在 GPU 記憶體中直到 backward 結束。

  • 如果 embedding 和最終線性沒有權重綁定,那麼我們 *可以* 進一步將 FSDP 應用於 embedding 和最終線性。對於權重綁定參數,我們要求它們屬於同一個 flat 參數(否則會被重複計算)。這將允許 embedding 在 forward 中使用後釋放,並且僅在 backward 結束時進行 all-gather。

  • 希望這能更好地理解 – 每個 FSDP 模組都會在其 module.parameters 中分配參數,除了已經分配給另一個巢狀 FSDP 模組的參數,並且 FSDP 模組的 forward 定義了其參數的“存活”間隔。因此,巢狀 nn.Module 結構會影響 all-gather/free 排程,從而影響記憶體/吞吐量效能。

文件

存取 PyTorch 的全面開發人員文件

檢視文件

教學課程

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源