torch.distributed.fsdp.fully_shard¶
PyTorch FSDP2 (fully_shard
)¶
PyTorch FSDP2 提供完全分片資料平行 (FSDP) 實作,目標是在高效能的 Eager 模式下運作,同時使用基於每個參數的分片以提高可用性。
如果您是 FSDP 的新手,我們建議您從 FSDP2 開始,因為它具有更佳的易用性。
如果您目前正在使用 FSDP1,請考慮評估以下差異,以了解是否應該切換到 FSDP2。
與 PyTorch FSDP1 (FullyShardedDataParallel
) 相比
FSDP2 使用基於
DTensor
的 dim-0 每參數分片,以實現更簡化的分片表示,相較於 FSDP1 的平面參數分片,同時保持相似的吞吐量效能。 更具體地說,FSDP2 在資料並行工作者之間按照 dim-0 對每個參數進行分塊(使用torch.chunk(dim=0)
),而 FSDP1 會將一組張量展平、連接和分塊在一起,這使得推斷每個工作者上存在哪些資料以及重新分片到不同的並行性變得複雜。 每參數分片提供了更直觀的使用者體驗,放寬了對凍結參數的限制,並允許無通信(分片)的 state dict,否則在 FSDP1 中需要 all-gather。FSDP2 實現了一種不同的記憶體管理方法來處理多串流的使用,避免了
torch.Tensor.record_stream
。這確保了確定性和預期的記憶體使用量,並且不需要像 FSDP1 的limit_all_gathers=True
那樣阻塞 CPU。FSDP2 公開了用於手動控制預取和集體調度的 API,允許高階使用者進行更多客製化。有關詳細資訊,請參閱下面的
FSDPModule
上的方法。FSDP2 簡化了一些 API 介面:例如,FSDP2 不直接支援完整的 state dict。相反,使用者可以使用
DTensor
API(如DTensor.full_tensor()
)或使用更高等級的 API(如 PyTorch Distributed Checkpoint 的 distributed state dict API)將包含DTensor
的分片 state dict 重新分片到完整的 state dict 本身。 此外,還刪除了一些其他的參數;有關詳細資訊,請參閱此處。
如果您是第一次使用 FSDP,或者如果以上任何一項符合您的使用案例,我們建議您考慮使用 FSDP2。
有關系統設計和實作的詳細資訊,請參閱此 RFC。
注意
torch.distributed.fsdp.fully_shard
目前處於原型狀態並且正在開發中。核心 API 可能不會更改,但如有必要,我們可能會進行一些 API 變更。
前端 API 是 fully_shard
,可以對 module
呼叫。
- torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy())[source]¶
將完全分片資料並行性 (FSDP) 應用於
module
,其中 FSDP 將模組參數、梯度和最佳化器狀態在資料並行工作者之間分片,以節省記憶體,但會增加通訊成本。在初始化時,FSDP 會根據
mesh
給定的資料並行工作者對模組的參數進行分片。 在 forward 之前,FSDP 會跨資料並行工作者 all-gather 分片的參數,以取得用於 forward 計算的未分片參數。 如果reshard_after_forward
為True
,則 FSDP 會在 forward 之後釋放未分片的參數,並在 backward 中重新 all-gather 它們,然後進行梯度計算。 在梯度計算之後,FSDP 會釋放未分片的參數,並跨資料並行工作者 reduce-scatter 未分片的梯度。此實作將分片的參數表示為在 dim-0 上分片的
DTensor
,而未分片的參數將像module
上的原始參數一樣(例如,如果原本是torch.Tensor
,則為torch.Tensor
)。 在module
上的模組 forward pre-hook 會 all-gather 參數,而module
上的模組 forward hook 會釋放它們(如果需要)。 類似的 backward hook 會 all-gather 參數,然後釋放參數並 reduce-scatter 梯度。由於將多個張量組合在一起進行一次集體通訊對於通訊效率至關重要,因此此實作使這種組合成為一等公民。 對
module
呼叫fully_shard()
會建構一個群組,該群組包含module.parameters()
中的參數,但那些已經從先前對子模組的呼叫分配給群組的參數除外。 這意味著應該從下而上地在您的模型上呼叫fully_shard()
。 每個群組的參數都在一次集體通訊中進行 all-gather,並且其梯度在一次集體通訊中進行 reduce-scatter。 將模型劃分為多個群組(「逐層」)可以節省最大的記憶體並實現通訊/計算重疊。 使用者通常不應僅在最頂層的根模組上呼叫fully_shard()
。- 參數
module (Union[nn.Module, List[nn.Module]) – 要使用 FSDP 分片並群組在一起進行通訊的模組。
mesh (Optional[DeviceMesh]) – 這個資料並行網格定義了分片和裝置。如果是 1D,則參數會完全分片到 1D 網格 (FSDP) 上,並使用
(Shard(0),)
位置。如果是 2D,則參數會在第 1 個維度上分片,並在第 0 個維度上複製 (HSDP),並使用(Replicate(), Shard(0))
位置。網格的裝置類型給出了用於通訊的裝置類型;如果是一個 CUDA 或類似 CUDA 的裝置類型,則我們使用目前的裝置。reshard_after_forward (Union[bool, int]) –
這控制了正向傳播後參數的行為,並且可以權衡記憶體和通訊
如果
True
,則會在正向傳播後重新分片參數,並在反向傳播中重新進行 all-gather。如果
False
,則會在正向傳播後將未分片的參數保留在記憶體中,並避免在反向傳播中進行 all-gather。如果是一個
int
,則表示正向傳播後重新分片到的世界大小。它應該是mesh
分片維度大小的一個非平凡除數(即排除 1 和維度大小本身)。一種選擇可能是節點內大小(例如torch.cuda.device_count()
)。這允許反向傳播中的 all-gather 在較小的世界大小上進行,但代價是比設定為True
更高的記憶體使用量。根 FSDP 狀態的值被特別設定為
False
作為一種啟發式方法,因為它的參數通常會立即被 all-gather 用於反向傳播。正向傳播後,註冊到模組的參數取決於此:如果
True
,則註冊的參數是分片的參數;如果False
,則為未分片的參數;否則為重新分片到較小網格的參數。要在正向傳播和反向傳播之間修改參數,則註冊的參數必須是分片的參數。對於False
或一個int
,可以通過reshard()
手動重新分片來完成。
shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – 這個可調用物件可用於覆蓋參數的分片位置,以便在非 dim-0 的維度上分片參數。如果這個可調用物件返回一個
Shard
位置(不是None
),則 FSDP 將根據該位置進行分片(例如Shard(1)
)。如果分片在非零維度上,我們目前需要偶數分片,即該維度上的張量維度大小必須可被 FSDP 分片網格大小整除。mp_policy (MixedPrecisionPolicy) – 這控制了混合精度策略,它為此模組提供了參數/縮減混合精度。 有關詳細資訊,請參閱
MixedPrecisionPolicy
。offload_policy (OffloadPolicy) – 這控制了卸載策略,它提供了參數/梯度/優化器狀態卸載。 有關詳細資訊,請參閱
OffloadPolicy
及其子類別。
呼叫 fully_shard(module)
會動態建構一個新的類別,該類別是 type(module)
和 FSDP 類別 FSDPModule
的子類別。例如,如果我們在模組 linear: nn.Linear
上呼叫 fully_shard(linear)
,則 FSDP 會建構一個新的類別 FSDPLinear
並將 linear
的類型變更為此。否則,fully_shard
不會變更模組結構和參數的完整名稱。類別 FSDPModule
允許在模組上提供一些 FSDP 特定的方法。
- class torch.distributed.fsdp.FSDPModule(*args, **kwargs)¶
-
- set_is_last_backward(is_last_backward)[source][來源]¶
設定下一次反向傳播是否為最後一次。在最後一次反向傳播時,FSDP 會等待未完成的梯度縮減,並清除內部資料結構以進行反向預取。這對於微批次處理非常有用。
- set_modules_to_backward_prefetch(modules)[source][source]¶
設定此 FSDP 模組應該在反向傳播中顯式預取 all-gather 的 FSDP 模組。這會覆蓋預設的反向預取實作,該實作基於反向後正向順序預取下一個 FSDP 模組。
傳遞包含前一個 FSDP 模組的單例列表,可以提供與預設重疊行為相同的 all-gather 重疊行為。若要更積極地重疊,則需要傳遞長度至少為 2 的列表,這將使用更多保留記憶體。
- 參數
modules (List[FSDPModule]) – 要預取的 FSDP 模組。
- set_modules_to_forward_prefetch(modules)[source][source]¶
設定此 FSDP 模組應該在正向傳播中顯式預取 all-gather 的 FSDP 模組。預取會在該模組的 all-gather 複製輸出後執行。
傳遞包含下一個 FSDP 模組的單例列表,可以提供與預設重疊行為相同的 all-gather 重疊行為,除了預取的 all-gather 會更早地從 CPU 發出。若要更積極地重疊,則需要傳遞長度至少為 2 的列表,這將使用更多保留記憶體。
- 參數
modules (List[FSDPModule]) – 要預取的 FSDP 模組。
- set_post_optim_event(event)[source][source]¶
為根 FSDP 模組設定一個最佳化器步驟後事件,以等待 all-gather streams 完成。
預設情況下,根 FSDP 模組會在目前 stream 上等待 all-gather streams 完成,以確保最佳化器步驟在 all-gather 之前完成。然而,如果在最佳化器步驟之後有不相關的計算,這可能會引入錯誤的依賴關係。這個 API 允許使用者提供他們自己的事件來等待。根模組等待事件後,該事件會被丟棄,因此每次迭代都應使用新的事件呼叫此 API。
- 參數
event (torch.Event) – 在最佳化器步驟之後記錄的事件,用於等待 all-gather streams 完成。
- set_reduce_scatter_divide_factor(factor)[source][source]¶
為 reduce-scatter 設定自訂的除法因子。這會變成使用 NCCL 的 PreMulSum 的自訂 reduce 運算,允許在縮減之前乘以該因子。
- 參數
factor (float) – 自訂的除法因子。
- set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source][source]¶
設定模組是否應該進行 all-reduce 梯度。這可用於實作梯度累加,僅使用 reduce-scatter,而不使用 HSDP 的 all-reduce。
- set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source][source]¶
設定模組是否應該同步梯度。這可用於實作無需通訊的梯度累加。對於 HSDP,這會同時控制 reduce-scatter 和 all-reduce。
- set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source][source]¶
設定模組是否應該在反向傳播後重新分片參數。這可以用於梯度累加期間,以更高的記憶體使用量來換取減少的通訊,因為在下一次正向傳播之前,無需重新 all-gather 未分片的參數。
- set_unshard_in_backward(unshard_in_backward)[source][source]¶
設定是否需要在反向傳播中解除 FSDP 模組的參數分片。這可用於專家案例,當使用者知道此 FSDP 模組的參數群組中的所有參數都不需要用於反向計算時 (例如,嵌入層)。
- unshard(async_op=False)[原始碼][原始碼]¶
透過分配記憶體並全部收集 (all-gathering) 參數來解除模組的參數分片。此方法 *並非* 遞迴的。解除分片遵循
MixedPrecisionPolicy
,因此如果設定了param_dtype
,它將按照param_dtype
進行全部收集。- 參數
async_op (bool) – 如果
True
,則傳回一個UnshardHandle
,其中包含一個wait()
方法,用於等待解除分片操作完成。 如果False
,則傳回None
並在此函數內部等待控制代碼。- 傳回類型
注意
如果
async_op=True
,則 FSDP 將在模組的 pre-forward 中等待尚未完成的解除分片。 只有當等待應該在 pre-forward 之前發生時,使用者才需要顯式呼叫wait()
。
- class torch.distributed.fsdp.UnshardHandle¶
用於等待
FSDPModule.unshard()
操作完成的控制代碼。
- torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[原始碼]¶
在
module
上註冊一個方法,以便將其視為 FSDP 的 forward 方法。FSDP 在 pre-forward 階段全部收集參數,並選擇性地在 post-forward 階段釋放參數 (取決於
reshard_after_forward
)。 預設情況下,FSDP 只知道對nn.Module.forward()
執行此操作。 此函數會修補使用者指定的方法,以分別在該方法之前/之後執行 pre/post-forward hook。 如果module
不是FSDPModule
,則這是一個 no-op。
- class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)¶
這會設定 FSDP 的混合精度。 與 autocast 不同,這是在模組層級而不是運算層級套用混合精度,這表示低精度 activation 會儲存以用於反向傳播,並且高精度到低精度的轉換只會在模組邊界產生。
FSDP 非常適合模組層級的混合精度,因為無論如何它都會將高精度分片參數保留在記憶體中。 換句話說,FSDP 不需要任何額外的記憶體來保留參數的高精度副本,以用於優化器步驟。
- 變數
param_dtype (Optional[torch.dtype]) – 這指定了未分片參數的 dtype,因此也是 forward/backward 計算和參數全部收集的 dtype。 如果這是
None
,則未分片的參數會使用原始 dtype。 優化器步驟會使用原始 dtype 中已分片的參數。 (預設值:None
)reduce_dtype (選擇性[torch.dtype]) – 此參數指定梯度縮減 (例如 reduce-scatter 或 all-reduce) 的 dtype。如果此參數為
None
,但param_dtype
不為None
,則縮減會使用計算 dtype。這可用於以完整精度執行梯度縮減,同時使用低精度進行計算。如果也透過set_requires_gradient_sync()
停用梯度縮減,則 FSDP 將使用reduce_dtype
累積梯度。(預設值:None
)output_dtype (選擇性[torch.dtype]) – 此參數指定用於轉換浮點正向傳播輸出的 dtype。這可用於幫助實作不同模組具有不同混合精度策略的情況。(預設值:
None
)cast_forward_inputs (bool) – 此參數指定 FSDP 是否應將正向傳播的浮點輸入張量轉換為
param_dtype
。
- class torch.distributed.fsdp.OffloadPolicy¶
此基底類別代表不卸載的策略,僅用作
offload_policy
參數的預設值。
- class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)¶
此卸載策略會將參數、梯度和最佳化器狀態卸載到 CPU。分片參數在 all-gather 之前會複製到主機到裝置。All-gather 參數會根據
reshard_after_forward
釋放。分片梯度會在反向傳播中從裝置複製到主機,而最佳化器步驟會在 CPU 上使用 CPU 最佳化器狀態執行。- 變數
pin_memory (bool) – 是否釘選分片參數和梯度記憶體。釘選記憶體可以更有效率地進行 H2D/D2H 複製,並且複製可以與計算重疊。但是,釘選的記憶體不能被其他進程使用。如果 CPU 記憶體不足,請將此參數設定為
False
。(預設值:True
)