• 文件 >
  • FullyShardedDataParallel
快捷方式

FullyShardedDataParallel

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[source][source]

一個用於在資料平行工作者間分片模組參數的包裝器。

其靈感來自 Xu et al. 以及來自 DeepSpeed 的 ZeRO Stage 3。FullyShardedDataParallel 通常簡寫為 FSDP。

要了解 FSDP 的內部機制,請參閱 FSDP 筆記

範例

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()

使用 FSDP 涉及到包裝您的模組,然後在之後初始化您的最佳化器。這是必需的,因為 FSDP 會更改參數變數。

設定 FSDP 時,您需要考慮目標 CUDA 裝置。如果裝置有一個 ID (dev_id),您有三個選項

  • 將模組放置在該裝置上

  • 使用 torch.cuda.set_device(dev_id) 設定裝置

  • dev_id 傳遞到 device_id 建構子引數中。

這確保了 FSDP 實例的計算裝置是目標裝置。對於選項 1 和 3,FSDP 初始化始終發生在 GPU 上。對於選項 2,FSDP 初始化發生在模組當前的裝置上,這可能是一個 CPU。

如果您正在使用 sync_module_states=True 標誌,您需要確保模組位於 GPU 上,或使用 device_id 引數來指定 FSDP 將在 FSDP 建構子中將模組移動到的 CUDA 裝置。這是必要的,因為 sync_module_states=True 需要 GPU 通訊。

FSDP 還負責將輸入張量移動到 forward 方法的 GPU 計算裝置,因此您無需手動將它們從 CPU 移動。

對於 use_orig_params=True, ShardingStrategy.SHARD_GRAD_OP 暴露的是未分片的參數,而不是 forward 之後的分片參數,這與 ShardingStrategy.FULL_SHARD 不同。 如果您想檢查梯度,您可以使用 summon_full_params 方法,並搭配 with_grads=True

使用 limit_all_gathers=True 時,您可能會看到 FSDP 在 pre-forward 階段出現間隙,其中 CPU 線程沒有發出任何核心。 這是故意的,並顯示了速率限制器的效果。 以這種方式同步 CPU 線程可以防止為後續的 all-gather 過度分配記憶體,並且實際上不應延遲 GPU 核心的執行。

由於 autograd 相關的原因,FSDP 在 forward 和 backward 計算期間,會將受管理模組的參數替換為 torch.Tensor 檢視 (views)。 如果您的模組的 forward 依賴於保存的參數引用,而不是每次迭代重新獲取引用,那麼它將看不到 FSDP 新建立的檢視,並且 autograd 將無法正常工作。

最後,當使用 sharding_strategy=ShardingStrategy.HYBRID_SHARD,且分片處理群組為節點內 (intra-node),而複製處理群組為節點間 (inter-node) 時,設定 NCCL_CROSS_NIC=1 可以幫助改善某些叢集設定中,複製處理群組的 all-reduce 時間。

限制

使用 FSDP 時,需要注意以下幾個限制

  • 當使用 CPU 卸載時,FSDP 目前不支援在 no_sync() 之外的梯度累積。 這是因為 FSDP 使用新減少的梯度,而不是與任何現有梯度累積,這可能會導致不正確的結果。

  • FSDP 不支援執行包含在 FSDP 實例中的子模組的 forward 傳遞。 這是因為子模組的參數將被分片,但子模組本身不是 FSDP 實例,因此其 forward 傳遞不會適當地 all-gather 完整參數。

  • 由於 FSDP 註冊 backward hooks 的方式,它不支援 double backwards。

  • FSDP 在凍結參數時有一些限制。 對於 use_orig_params=False,每個 FSDP 實例必須管理所有凍結或所有非凍結的參數。 對於 use_orig_params=True,FSDP 支援混合凍結和非凍結參數,但建議避免這樣做,以防止高於預期的梯度記憶體使用量。

  • 截至 PyTorch 1.12,FSDP 對共享參數的支援有限。 如果您的用例需要增強的共享參數支援,請在此 issue 中發文。

  • 您應該避免在 forward 和 backward 之間修改參數,而不使用 summon_full_params 上下文,因為修改可能不會持續。

參數
  • module (nn.Module) – 這是要用 FSDP 包裝的模組。

  • process_group (選填[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 這是用於對模型進行分片處理的 process group,因此也是用於 FSDP 的 all-gather 和 reduce-scatter 集合通訊的 process group。如果為 None,則 FSDP 會使用預設的 process group。對於混合分片策略,例如 ShardingStrategy.HYBRID_SHARD,使用者可以傳入一個 process group 的 tuple,分別代表用於分片和複製的 group。如果為 None,則 FSDP 會為使用者建構 process group,以便在節點內進行分片,並在節點間進行複製。(預設:None

  • sharding_strategy (選填[ShardingStrategy]) – 此設定會配置分片策略,該策略可能會在記憶體節省和通訊開銷之間進行權衡。詳情請參閱 ShardingStrategy。(預設:FULL_SHARD

  • cpu_offload (選填[CPUOffload]) – 此設定會配置 CPU 卸載。如果設定為 None,則不會發生 CPU 卸載。詳情請參閱 CPUOffload。(預設:None

  • auto_wrap_policy (選填[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) –

    此設定指定將 FSDP 應用於 module 的子模組的策略,這是通訊和計算重疊所必需的,因此會影響效能。如果為 None,則 FSDP 僅適用於 module,並且使用者應手動將 FSDP 應用於父模組(從下而上)。為了方便起見,此設定直接接受 ModuleWrapPolicy,這允許使用者指定要包裝的模組類別(例如,transformer block)。否則,此設定應為一個 callable 物件,它接受三個參數 module: nn.Modulerecurse: boolnonwrapped_numel: int,並且應返回一個 bool,指定是否應在 recurse=False 時將 FSDP 應用於傳入的 module,或是否應在 recurse=True 時繼續遍歷模組的子樹。使用者可以向 callable 物件添加其他參數。torch.distributed.fsdp.wrap.py 中的 size_based_auto_wrap_policy 提供了一個 callable 物件的範例,如果其子樹中的參數超過 100M numel,則會將 FSDP 應用於模組。我們建議在應用 FSDP 後列印模型,並根據需要進行調整。

    範例

    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     nonwrapped_numel: int,
    >>>     # Additional custom arguments
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return nonwrapped_numel >= min_num_params
    >>> # Configure a custom `min_num_params`
    >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
    

  • backward_prefetch (選填[BackwardPrefetch]) – 此設定配置了 all-gather 的顯式反向預取。如果為 None,則 FSDP 不會進行反向預取,並且在反向傳遞中沒有通訊和計算重疊。詳情請參閱 BackwardPrefetch。(預設:BACKWARD_PRE

  • mixed_precision (選填[MixedPrecision]) – 此設定配置了 FSDP 的原生混合精度。如果設定為 None,則不使用混合精度。否則,可以設定參數、緩衝區和梯度縮減的 dtype。詳情請參閱 MixedPrecision。(預設:None

  • ignored_modules (選填[Iterable[torch.nn.Module]]) – 此實例會忽略這些模組自己的參數以及子模組的參數和緩衝區。ignored_modules 中直接的模組都不應是 FullyShardedDataParallel 實例,並且如果任何已經建構的 FullyShardedDataParallel 子模組嵌套在此實例下,則不會被忽略。當使用 auto_wrap_policy 時,或者如果參數的分片不受 FSDP 管理時,可以使用此引數來避免以模組粒度對特定參數進行分片。(預設:None

  • param_init_fn (選填[Callable[[nn.Module], None]]) –

    一個 Callable[torch.nn.Module] -> None,用於指定如何將目前位於 meta 裝置上的模組初始化到實際裝置上。從 v1.12 開始,FSDP 透過 is_meta 檢測參數或緩衝區位於 meta 裝置上的模組,如果指定了 param_init_fn,則應用它,否則調用 nn.Module.reset_parameters()。對於這兩種情況,實作都應初始化模組的參數/緩衝區,而不是其子模組的參數/緩衝區。這是為了避免重新初始化。此外,FSDP 還支援透過 torchdistX 的 (https://github.com/pytorch/torchdistX) deferred_init() API 進行延遲初始化,其中延遲的模組透過調用 param_init_fn(如果指定)或 torchdistX 的預設 materialize_module() 來初始化。如果指定了 param_init_fn,則它會應用於所有 meta 裝置模組,這意味著它應該根據模組類型進行區分處理。FSDP 在參數扁平化和分片之前調用初始化函式。

    範例

    >>> module = MyModule(device="meta")
    >>> def my_init_fn(module: nn.Module):
    >>>     # E.g. initialize depending on the module type
    >>>     ...
    >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
    >>> print(next(fsdp_model.parameters()).device) # current CUDA device
    >>> # With torchdistX
    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
    >>> # Will initialize via deferred_init.materialize_module().
    >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
    

  • device_id (Optional[Union[int, torch.device]]) – 一個 inttorch.device,用於指定 FSDP 初始化發生的 CUDA 裝置,包括模組初始化(如果需要)和參數分片。如果 module 在 CPU 上,則應指定此項以提高初始化速度。如果已設定預設的 CUDA 裝置(例如,透過 torch.cuda.set_device),則使用者可以將 torch.cuda.current_device 傳遞給此項。(預設:None

  • sync_module_states (bool) – 如果 True,則每個 FSDP 模組將從 rank 0 廣播模組參數和緩衝區,以確保它們在各個 rank 上複製(從而為此建構子增加通信開銷)。這有助於以記憶體有效的方式透過 load_state_dict 載入 state_dict 檢查點。有關範例,請參閱 FullStateDictConfig。(預設:False

  • forward_prefetch (bool) – 如果 True,則 FSDP 會在當前向前計算之前顯式預取下一次向前傳遞的 all-gather。這僅適用於受 CPU 限制的工作負載,在這種情況下,更早地發出下一次 all-gather 可能會改善重疊。這僅應用於靜態圖模型,因為預取遵循第一次迭代的執行順序。(預設:False

  • limit_all_gathers (bool) – 如果 True,則 FSDP 顯式地同步 CPU 線程,以確保僅來自兩個連續 FSDP 實例的 GPU 記憶體使用量(當前執行計算的實例和下一個 all-gather 被預取的實例)。如果 False,則 FSDP 允許 CPU 線程發出所有 all-gather,而無需任何額外同步。(預設:True)我們通常將此功能稱為「速率限制器」。僅應在特定 CPU 限制的工作負載中將此標誌設定為 False,這些工作負載具有較低的記憶體壓力,在這種情況下,CPU 線程可以積極地發出所有核心,而無需考慮 GPU 記憶體使用量。

  • use_orig_params (bool) – 將此設定為 True 會讓 FSDP 使用 module 的原始參數。 FSDP 透過 nn.Module.named_parameters() 向使用者公開這些原始參數,而不是 FSDP 的內部 FlatParameter。 這意味著最佳化器步驟在原始參數上執行,從而實現每個原始參數的超參數。 FSDP 保留原始參數變數,並在其取消分片和分片形式之間操作其資料,它們始終是分別進入底層取消分片或分片 FlatParameter 的檢視。 使用目前的演算法,分片形式始終是 1D,失去了原始張量結構。 原始參數可能具有給定 rank 的所有、部分或沒有資料存在。 在沒有的情況下,其資料將像一個大小為 0 的空張量。 使用者不應編寫依賴於給定 rank 的分片形式的原始參數中存在的資料的程式。 使用 torch.compile() 需要 True。 將此設定為 False 會透過 nn.Module.named_parameters() 向使用者公開 FSDP 的內部 FlatParameter。(預設:False

  • ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – 將被此 FSDP 實例忽略的參數或模組,這意味著參數未被分片,並且它們的梯度未跨 rank 縮減。 此參數與現有的 ignored_modules 參數統一,並且我們可能會很快棄用 ignored_modules。 為了向後相容性,我們同時保留 ignored_statesignored_modules`,但是 FSDP 僅允許將其中之一指定為非 None

  • device_mesh (選用[DeviceMesh]) – DeviceMesh 可以用來替代 process_group。當傳入 device_mesh 時,FSDP 會使用底層的 process groups 進行 all-gather 和 reduce-scatter 集體通訊。因此,這兩個參數需要互斥。對於混合分片策略,例如 ShardingStrategy.HYBRID_SHARD,使用者可以傳入 2D DeviceMesh 而不是 process groups 的元組。對於 2D FSDP + TP,使用者需要傳入 device_mesh 而不是 process_group。有關 DeviceMesh 的更多資訊,請訪問:https://pytorch.dev.org.tw/tutorials/recipes/distributed_device_mesh.html

apply(fn)[原始碼][原始碼]

fn 遞迴地應用於每個子模組(由 .children() 傳回)以及自身。

典型的用法包括初始化模型的參數(另請參閱 torch.nn.init)。

torch.nn.Module.apply 相比,這個版本額外地在應用 fn 之前收集完整參數。不應從另一個 summon_full_params 上下文中調用。

參數

fn (Module -> None) – 要應用於每個子模組的函數

傳回

自身

傳回類型

Module

check_is_root()[原始碼][原始碼]

檢查這個實例是否為 root FSDP 模組。

傳回類型

bool

clip_grad_norm_(max_norm, norm_type=2.0)[原始碼][原始碼]

裁剪所有參數的梯度範數。

範數的計算是將所有參數的梯度視為單個向量,並且梯度會原地修改。

參數
  • max_norm (float or int) – 梯度的最大範數

  • norm_type (float or int) – 使用的 p-norm 的類型。可以是 'inf' 表示無窮範數。

傳回

參數的總範數(視為單個向量)。

傳回類型

Tensor

如果每個 FSDP 實例都使用 NO_SHARD,表示沒有梯度在 ranks 之間分片,那麼您可以直接使用 torch.nn.utils.clip_grad_norm_()

如果至少有一些 FSDP 實例使用分片策略(即,除了 NO_SHARD 之外的策略),那麼您應該使用這個方法而不是 torch.nn.utils.clip_grad_norm_(),因為這個方法會處理梯度在 ranks 之間分片的事實。

傳回的總範數將具有所有參數/梯度中「最大」的 dtype,如 PyTorch 的類型提升語義所定義。 例如,如果所有參數/梯度都使用低精度 dtype,那麼傳回的範數的 dtype 將是該低精度 dtype,但如果存在至少一個使用 FP32 的參數/梯度,那麼傳回的範數的 dtype 將是 FP32。

警告

這需要在所有 ranks 上調用,因為它使用集體通訊。

static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[原始碼][原始碼]

展平一個分片優化器狀態字典。

API 類似於 shard_full_optim_state_dict()。唯一的區別是輸入的 sharded_optim_state_dict 應該從 sharded_optim_state_dict() 傳回。因此,每個 rank 都會有 all-gather 調用來收集 ShardedTensor s。

參數
傳回

參考 shard_full_optim_state_dict()

傳回類型

Dict[str, Any]

forward(*args, **kwargs)[原始碼][原始碼]

執行封裝模組的前向傳遞,插入 FSDP 特有的前向傳遞前後分片邏輯。

傳回類型

Any

static fsdp_modules(module, root_only=False)[原始碼][原始碼]

傳回所有巢狀的 FSDP 實例。

這可能包含 module 本身,且只有在 root_only=True 時才會包含 FSDP 根模組。

參數
  • module (torch.nn.Module) – 根模組,其可能是或可能不是一個 FSDP 模組。

  • root_only (bool) – 是否只傳回 FSDP 根模組。(預設:False

傳回

嵌套在輸入 module 中的 FSDP 模組。

傳回類型

List[FullyShardedDataParallel]

static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[原始碼][原始碼]

傳回完整的優化器 state-dict。

在 rank 0 上整合完整的優化器狀態,並依照 torch.optim.Optimizer.state_dict() 的慣例,將其作為一個 dict 傳回,即具有鍵 "state""param_groups"。在 model 中包含的 FSDP 模組中,扁平化的參數會被映射回其未扁平化的參數。

由於它使用集體通訊,因此需要在所有 rank 上調用此方法。但是,如果 rank0_only=True,則 state dict 僅在 rank 0 上填充,並且所有其他 rank 都會傳回一個空的 dict

torch.optim.Optimizer.state_dict() 不同,此方法使用完整的參數名稱作為鍵,而不是參數 ID。

torch.optim.Optimizer.state_dict() 類似,優化器 state dict 中包含的張量不會被克隆,因此可能會出現別名問題。 為了獲得最佳實踐,請考慮立即保存傳回的優化器 state dict,例如使用 torch.save()

參數
  • model (torch.nn.Module) – 根模組(可能是或可能不是一個 FullyShardedDataParallel 實例),其參數已傳遞到優化器 optim 中。

  • optim (torch.optim.Optimizer) – 用於 model 參數的優化器。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 傳遞到優化器 optim 的輸入,表示參數組的 list 或參數的可迭代物件;如果為 None,則此方法假定輸入為 model.parameters()。 此參數已被棄用,不再需要傳遞它。(預設:None

  • rank0_only (bool) – 如果 True,則僅在 rank 0 上保存已填充的 dict;如果 False,則在所有 rank 上保存它。(預設:True

  • group (dist.ProcessGroup) – 模型的 process group,如果使用預設的 process group,則為 None。(預設:None

傳回

一個 dict,包含 model 原始未扁平化參數的優化器狀態,並遵循 torch.optim.Optimizer.state_dict() 的慣例,包含 “state” 和 “param_groups” 鍵。如果 rank0_only=True,則非零 rank 將返回一個空的 dict

傳回類型

Dict[str, Any]

static get_state_dict_type(module)[source][source]

取得 state_dict_type 以及以 module 為根的 FSDP 模組的相應配置。

目標模組不一定是 FSDP 模組。

傳回

一個 StateDictSettings,包含目前設定的 state_dict_type 和 state_dict / optim_state_dict 配置。

引發
  • 如果不同 FSDP 子模組的 StateDictSettings 不同,則引發 AssertionError。

  • FSDP submodules differ.

傳回類型

StateDictSettings

property module: Module

返回被封裝的模組。

named_buffers(*args, **kwargs)[source][source]

返回一個模組 buffers 的迭代器,產生 buffer 的名稱和 buffer 本身。

攔截 buffer 名稱,並在 summon_full_params() 上下文管理器內部時,移除所有 FSDP 特定的扁平化 buffer 前綴。

傳回類型

Iterator[Tuple[str, Tensor]]

named_parameters(*args, **kwargs)[source][source]

返回一個模組參數的迭代器,產生參數的名稱和參數本身。

攔截參數名稱,並在 summon_full_params() 上下文管理器內部時,移除所有 FSDP 特定的扁平化參數前綴。

傳回類型

Iterator[Tuple[str, Parameter]]

no_sync()[source][source]

停用跨 FSDP 實例的梯度同步。

在這個上下文中,梯度將累積在模組變數中,這些變數將在退出上下文後的第一個前向-後向傳遞中同步。這應僅在根 FSDP 實例上使用,並將遞迴地應用於所有子 FSDP 實例。

注意

這可能會導致更高的記憶體使用量,因為 FSDP 將累積完整的模型梯度(而不是梯度分片),直到最終同步。

注意

當與 CPU 卸載一起使用時,梯度將不會在上下文管理器內部卸載到 CPU。相反,它們只會在最終同步之後立即卸載。

傳回類型

產生器

static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source][source]

轉換與分片模型相對應的優化器的 state-dict。

給定的 state-dict 可以轉換為以下三種類型之一:1) 完整優化器 state_dict,2) 分片優化器 state_dict,3) 本地優化器 state_dict。

對於完整優化器 state_dict,所有狀態都是未扁平化且未分片的。可以通過 state_dict_type() 指定僅 Rank0 和僅 CPU,以避免 OOM。

對於分片優化器 state_dict,所有狀態都是未扁平化但已分片的。可以通過 state_dict_type() 指定僅 CPU,以進一步節省記憶體。

對於本地 state_dict,不會執行任何轉換。但狀態將從 nn.Tensor 轉換為 ShardedTensor,以表示其分片性質(目前尚不支援)。

範例

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
參數
  • model (torch.nn.Module) – 根模組(可能是或可能不是一個 FullyShardedDataParallel 實例),其參數已傳遞到優化器 optim 中。

  • optim (torch.optim.Optimizer) – 用於 model 參數的優化器。

  • optim_state_dict (Dict[str, Any]) – 要轉換的目標優化器 state_dict。如果值為 None,則將使用 optim.state_dict()。(預設: None

  • group (dist.ProcessGroup) – 模型所屬的進程群組,參數會在此群組中進行分片;如果使用預設進程群組,則為 None。(預設:None

傳回

包含 model 的最佳化器狀態的 dict。最佳化器狀態的分片基於 state_dict_type

傳回類型

Dict[str, Any]

static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source][source]

轉換最佳化器狀態字典,使其可以載入到與 FSDP 模型關聯的最佳化器中。

給定一個透過 optim_state_dict() 轉換的 optim_state_dict,它會被轉換為可以載入到 optim 的扁平最佳化器狀態字典,而 optimmodel 的最佳化器。model 必須由 FullyShardedDataParallel 進行分片。

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>>     model,
>>>     optim,
>>>     optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
參數
  • model (torch.nn.Module) – 根模組(可能是或可能不是一個 FullyShardedDataParallel 實例),其參數已傳遞到優化器 optim 中。

  • optim (torch.optim.Optimizer) – 用於 model 參數的優化器。

  • optim_state_dict (Dict[str, Any]) – 要載入的最佳化器狀態。

  • is_named_optimizer (bool) – 此最佳化器是 NamedOptimizer 還是 KeyedOptimizer。只有在 optim 是 TorchRec 的 KeyedOptimizer 或 torch.distributed 的 NamedOptimizer 時,才設定為 True。

  • load_directly (bool) – 如果設定為 True,此 API 也會在傳回結果之前呼叫 optim.load_state_dict(result)。否則,使用者有責任呼叫 optim.load_state_dict() (預設:False

  • group (dist.ProcessGroup) – 模型所屬的進程群組,參數會在此群組中進行分片;如果使用預設進程群組,則為 None。(預設:None

傳回類型

Dict[str, Any]

register_comm_hook(state, hook)[source][source]

註冊一個通訊鉤子。

這是一個增強功能,為使用者提供靈活的鉤子,使用者可以指定 FSDP 如何跨多個 worker 聚合梯度。此鉤子可用於實現多種算法,例如 GossipGrad 和梯度壓縮,這些算法涉及不同的通訊策略,用於在使用 FullyShardedDataParallel 進行訓練時同步參數。

警告

FSDP 通訊鉤子應在執行初始前向傳遞之前註冊,且只能註冊一次。

參數
  • state (object) –

    傳遞給鉤子以在訓練過程中維護任何狀態資訊。範例包括梯度壓縮中的錯誤回饋,以及 GossipGrad 中下一次要通訊的對等節點等。它由每個 worker 本地儲存,並由 worker 上的所有梯度張量共享。

  • hook (Callable) – 可呼叫對象,具有以下簽名之一:1) hook: Callable[torch.Tensor] -> None:此函數接收一個 Python 張量,該張量表示相對於此 FSDP 單元正在封裝的模型(未被其他 FSDP 子單元封裝)的所有變數的完整、扁平、未分片的梯度。然後它執行所有必要的處理並傳回 None;2) hook: Callable[torch.Tensor, torch.Tensor] -> None:此函數接收兩個 Python 張量,第一個張量表示相對於此 FSDP 單元正在封裝的模型(未被其他 FSDP 子單元封裝)的所有變數的完整、扁平、未分片的梯度。後者表示一個預先調整大小的張量,用於在縮減後儲存分片梯度的區塊。在兩種情況下,可呼叫對象都會執行所有必要的處理並傳回 None。預計簽名 1 的可呼叫對象會處理 NO_SHARD 情況下的梯度通訊。預計簽名 2 的可呼叫對象會處理分片情況下的梯度通訊。

static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source][source]

重新鍵入最佳化器狀態字典 optim_state_dict 以使用金鑰類型 optim_state_key_type

這可用於實現來自具有 FSDP 實例的模型和沒有 FSDP 實例的模型之間的最佳化器狀態字典的相容性。

要重新鍵入 FSDP 完整最佳化器狀態字典(即來自 full_optim_state_dict())以使用參數 ID 並可載入到非封裝模型

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

要重新鍵入來自非封裝模型的正常最佳化器狀態字典,使其可載入到封裝模型

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
傳回

使用 optim_state_key_type 指定的參數金鑰重新鍵入的最佳化器狀態字典。

傳回類型

Dict[str, Any]

static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source][source]

將 rank 0 上的完整最佳化器狀態字典 (optimizer state dict) 分散 (scatter) 到所有其他 ranks。

傳回每個 rank 上分片 (sharded) 的最佳化器狀態字典。 傳回值與 shard_full_optim_state_dict() 相同,且在 rank 0 上,第一個引數應該是 full_optim_state_dict() 的傳回值。

範例

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

注意

shard_full_optim_state_dict()scatter_full_optim_state_dict() 都可以用來取得分片的最佳化器狀態字典以進行載入。假設完整最佳化器狀態字典駐留在 CPU 記憶體中,前者要求每個 rank 在 CPU 記憶體中都具有完整的字典,每個 rank 單獨對字典進行分片,而無需任何通訊,而後者僅要求 rank 0 在 CPU 記憶體中具有完整的字典,其中 rank 0 將每個分片移動到 GPU 記憶體(對於 NCCL)並將其適當地通訊給各個 ranks。因此,前者具有較高的總體 CPU 記憶體成本,而後者具有較高的通訊成本。

參數
  • full_optim_state_dict (Optional[Dict[str, Any]]) – 最佳化器狀態字典,對應於未展平的參數,如果位於 rank 0 上,則保存完整的非分片最佳化器狀態;此引數在非零 ranks 上會被忽略。

  • model (torch.nn.Module) – 根模組(可能是也可能不是 FullyShardedDataParallel 實例),其參數對應於 full_optim_state_dict 中的最佳化器狀態。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 傳遞到最佳化器的輸入,表示參數群組的 list 或參數的可迭代對象;如果為 None,則此方法假定輸入為 model.parameters()。此引數已棄用,不再需要傳遞。(預設值:None

  • optim (Optional[torch.optim.Optimizer]) – 將載入此方法傳回的狀態字典的最佳化器。這是比使用 optim_input 更推薦的引數。(預設值:None

  • group (dist.ProcessGroup) – 模型的 process group,如果使用預設的 process group,則為 None。(預設:None

傳回

完整最佳化器狀態字典現在已重新對應到展平的參數,而不是未展平的參數,並且僅限於包含此 rank 的最佳化器狀態部分。

傳回類型

Dict[str, Any]

static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source][source]

設定目標模組的所有後代 FSDP 模組的 state_dict_type

還可以取得模型和最佳化器的狀態字典的(可選)配置。目標模組不一定是 FSDP 模組。如果目標模組是 FSDP 模組,則其 state_dict_type 也會被更改。

注意

此 API 應僅針對最上層(根)模組呼叫。

注意

此 API 使使用者能夠透明地使用傳統的 state_dict API 在根 FSDP 模組由另一個 nn.Module 包裝的情況下取得模型檢查點。例如,以下程式碼將確保在所有非 FSDP 實例上呼叫 state_dict,同時將其分派到 FSDP 的 sharded_state_dict 實作中

範例

>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     state_dict_config = ShardedStateDictConfig(offload_to_cpu=True),
>>>     optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
參數
  • module (torch.nn.Module) – 根模組。

  • state_dict_type (StateDictType) – 要設定的所需 state_dict_type

  • state_dict_config (選擇性[StateDictConfig]) – 目標 state_dict_type 的設定。

  • optim_state_dict_config (選擇性[OptimStateDictConfig]) – optimizer state dict 的設定。

傳回

包含先前 state_dict 類型和模組設定的 StateDictSettings。

傳回類型

StateDictSettings

static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[原始碼][原始碼]

對完整的 optimizer state-dict 進行分片。

full_optim_state_dict 中的狀態重新對應到扁平化的參數,而不是未扁平化的參數,並限制為僅限此 rank 的 optimizer 狀態部分。第一個參數應該是 full_optim_state_dict() 的回傳值。

範例

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

注意

shard_full_optim_state_dict()scatter_full_optim_state_dict() 都可以用來取得分片的最佳化器狀態字典以進行載入。假設完整最佳化器狀態字典駐留在 CPU 記憶體中,前者要求每個 rank 在 CPU 記憶體中都具有完整的字典,每個 rank 單獨對字典進行分片,而無需任何通訊,而後者僅要求 rank 0 在 CPU 記憶體中具有完整的字典,其中 rank 0 將每個分片移動到 GPU 記憶體(對於 NCCL)並將其適當地通訊給各個 ranks。因此,前者具有較高的總體 CPU 記憶體成本,而後者具有較高的通訊成本。

參數
  • full_optim_state_dict (Dict[str, Any]) – Optimizer state dict 對應於未扁平化的參數,並保存完整、未分片的 optimizer 狀態。

  • model (torch.nn.Module) – 根模組(可能是也可能不是 FullyShardedDataParallel 實例),其參數對應於 full_optim_state_dict 中的最佳化器狀態。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 傳遞到最佳化器的輸入,表示參數群組的 list 或參數的可迭代對象;如果為 None,則此方法假定輸入為 model.parameters()。此引數已棄用,不再需要傳遞。(預設值:None

  • optim (Optional[torch.optim.Optimizer]) – 將載入此方法傳回的狀態字典的最佳化器。這是比使用 optim_input 更推薦的引數。(預設值:None

傳回

完整最佳化器狀態字典現在已重新對應到展平的參數,而不是未展平的參數,並且僅限於包含此 rank 的最佳化器狀態部分。

傳回類型

Dict[str, Any]

static sharded_optim_state_dict(model, optim, group=None)[原始碼][原始碼]

以其分片形式回傳 optimizer state-dict。

此 API 類似於 full_optim_state_dict(),但此 API 將所有非零維度的狀態區塊化為 ShardedTensor 以節省記憶體。只有當模型 state_dict 是在使用 context manager with state_dict_type(SHARDED_STATE_DICT): 導出時,才應使用此 API。

有關詳細用法,請參閱 full_optim_state_dict()

警告

回傳的 state dict 包含 ShardedTensor,不能直接被常規的 optim.load_state_dict 使用。

傳回類型

Dict[str, Any]

static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[原始碼][原始碼]

設定目標模組的所有後代 FSDP 模組的 state_dict_type

此 context manager 具有與 set_state_dict_type() 相同的功能。 請閱讀 set_state_dict_type() 的文件以了解詳細資訊。

範例

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>> ):
>>>     checkpoint = model.state_dict()
參數
  • module (torch.nn.Module) – 根模組。

  • state_dict_type (StateDictType) – 要設定的所需 state_dict_type

  • state_dict_config (選擇性[StateDictConfig]) – 目標 state_dict_type 的模型 state_dict 設定。

  • optim_state_dict_config (選擇性[OptimStateDictConfig]) – 目標 state_dict_type 的 optimizer state_dict 設定。

傳回類型

產生器

static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source][source]

利用此情境管理器 (context manager) 公開 FSDP 實例的完整參數 (full params)。

在模型進行前向 (forward)/反向 (backward) 之後,這對於取得參數以進行額外的處理或檢查很有用。它可以接受一個非 FSDP 模組,並且會根據 recurse 參數,為所有包含的 FSDP 模組及其子模組調用 (summon) 完整參數。

注意

這可以用於內部的 FSDP。

注意

不能在正向或反向傳遞中使用。正向和反向傳遞也不能從此情境中開始。

注意

參數將在情境管理器退出後還原為它們的本地分片 (local shards),儲存行為與正向傳遞相同。

注意

可以修改完整的參數,但只有對應於本地參數分片的部分會在情境管理器退出後持續存在 (除非 writeback=False,在這種情況下,變更將被丟棄)。在 FSDP 不對參數進行分片的情況下,目前僅在 world_size == 1NO_SHARD 配置時,無論 writeback 如何,修改都會持續存在。

注意

此方法適用於本身不是 FSDP 但可能包含多個獨立 FSDP 單元的模組。在這種情況下,給定的參數將應用於所有包含的 FSDP 單元。

警告

請注意,目前不支援與 writeback=True 結合使用的 rank0_only=True,並且會引發錯誤。這是因為模型參數形狀在情境中的各個 rank 之間會有所不同,並且寫入它們可能會導致情境退出時各個 rank 之間的不一致。

警告

請注意,offload_to_cpurank0_only=False 會導致完整參數被冗餘地複製到位於同一機器上的 GPU 的 CPU 記憶體中,這可能會導致 CPU OOM 的風險。建議將 offload_to_cpurank0_only=True 一起使用。

參數
  • recurse (bool, Optional) – 遞迴地為巢狀 FSDP 實例調用所有參數 (預設:True)。

  • writeback (bool, Optional) – 如果 False,則在情境管理器退出後,對參數的修改將被丟棄;停用此功能可能會稍微提高效率 (預設:True)

  • rank0_only (bool, Optional) – 如果 True,則僅在全域 rank 0 上實現完整參數。這意味著在情境中,只有 rank 0 具有完整參數,而其他 rank 將具有分片參數。請注意,不支援將 rank0_only=Truewriteback=True 一起設定,因為模型參數形狀在情境中的各個 rank 之間會有所不同,並且寫入它們可能會導致情境退出時各個 rank 之間的不一致。

  • offload_to_cpu (bool, Optional) – 如果 True,則完整參數會卸載到 CPU。請注意,目前只有在參數被分片時才會發生這種卸載 (只有在 world_size = 1 或 NO_SHARD 配置的情況下才不是這樣)。建議將 offload_to_cpurank0_only=True 一起使用,以避免將模型參數的冗餘副本卸載到相同的 CPU 記憶體。

  • with_grads (bool, Optional) – 如果 True,梯度也會與參數一起取消分片。目前,僅當將 use_orig_params=True 傳遞給 FSDP 建構函式,並將 offload_to_cpu=False 傳遞給此方法時,才支援此功能。(預設:False)

傳回類型

產生器

class torch.distributed.fsdp.BackwardPrefetch(value)[source][source]

此設定配置顯式的反向預取 (backward prefetching),這透過在反向傳遞中啟用通訊和計算重疊來提高吞吐量,但會稍微增加記憶體使用量。

  • BACKWARD_PRE:這可以實現最多的重疊,但會最大程度地增加記憶體使用量。它會在計算目前參數集的梯度之前,預取下一組參數。這會重疊下一個 all-gather目前的梯度計算,並且在峰值時,它會在記憶體中保存目前的一組參數、下一組參數和目前的梯度集。

  • BACKWARD_POST:這個選項可以減少重疊,但需要的記憶體也較少。它會在目前參數集的梯度計算完成後,預先載入下一組參數。這樣做可以重疊目前的 reduce-scatter下一次的梯度計算,並且在為下一組參數分配記憶體之前,釋放目前的參數集。在記憶體使用高峰時,只會持有下一組參數和目前一組的梯度。

  • FSDP 的 backward_prefetch 參數接受 None,這會完全停用反向預先載入。這樣做不會有重疊,也不會增加記憶體使用量。一般來說,我們不建議使用此設定,因為它可能會大幅降低吞吐量。

更多技術背景:對於使用 NCCL 後端的單個進程組,任何集合操作(即使從不同的 stream 發出)都會競爭相同的每個裝置的 NCCL stream,這意味著發出集合操作的相對順序對於重疊很重要。兩個反向預先載入的值對應於不同的發出順序。

class torch.distributed.fsdp.ShardingStrategy(value)[source][source]

這指定了 FullyShardedDataParallel 用於分散式訓練的分片策略。

  • FULL_SHARD:參數、梯度和優化器狀態都會被分片。對於參數,此策略會在前向傳播之前取消分片(透過 all-gather),在前向傳播之後重新分片,在反向計算之前取消分片,在反向計算之後重新分片。對於梯度,它會在反向計算之後同步並分片它們(透過 reduce-scatter)。分片的優化器狀態會在每個 rank 本地更新。

  • SHARD_GRAD_OP:梯度和優化器狀態在計算期間被分片,此外,參數在計算之外被分片。對於參數,此策略在前向傳播之前取消分片,在前向傳播之後不重新分片它們,並且僅在反向計算之後重新分片它們。分片的優化器狀態會在每個 rank 本地更新。在 no_sync() 內部,參數在反向計算後不會被重新分片。

  • NO_SHARD:參數、梯度和優化器狀態不會被分片,而是跨 rank 複製,類似於 PyTorch 的 DistributedDataParallel API。對於梯度,此策略會在反向計算之後同步它們(透過 all-reduce)。未分片的優化器狀態會在每個 rank 本地更新。

  • HYBRID_SHARD:在一個節點內應用 FULL_SHARD,並跨節點複製參數。這樣做可以減少通信量,因為昂貴的 all-gather 和 reduce-scatter 僅在一個節點內完成,這對於中型模型可能更有效。

  • _HYBRID_SHARD_ZERO2:在一個節點內應用 SHARD_GRAD_OP,並跨節點複製參數。這就像 HYBRID_SHARD,除了這可能提供更高的吞吐量,因為未分片的參數在前向傳遞之後不會被釋放,從而節省了反向傳播之前的 all-gather。

class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[source][source]

這配置了 FSDP 原生的混合精度訓練。

變數
  • param_dtype (Optional[torch.dtype]) – 這指定了模型參數在前向和反向傳播期間的 dtype,因此也是前向和反向計算的 dtype。在前向和反向傳播之外,分片的參數保持完整精度(例如,用於優化器步驟),並且對於模型檢查點,參數始終以完整精度保存。(預設值:None

  • reduce_dtype (Optional[torch.dtype]) – 這指定了梯度縮減的 dtype(即 reduce-scatter 或 all-reduce)。如果這是 Noneparam_dtype 不是 None,則這將採用 param_dtype 值,仍然以低精度運行梯度縮減。允許這與 param_dtype 不同,例如,強制梯度縮減以全精度運行。(預設值:None

  • buffer_dtype (Optional[torch.dtype]) – 這指定了 buffer 的 dtype。FSDP 不會分片 buffer。相反,FSDP 在第一個前向傳播中將它們轉換為 buffer_dtype,然後將它們保持在該 dtype 中。對於模型檢查點,buffer 以完整精度保存,除了 LOCAL_STATE_DICT。(預設值:None

  • keep_low_precision_grads (bool) – 如果 False,則 FSDP 會在反向傳播後將梯度向上轉換為完整精度,以準備進行優化器步驟。如果 True,則 FSDP 會將梯度保留在用於梯度縮減的 dtype 中,如果使用支援低精度運行的自定義優化器,則可以節省記憶體。(預設值:False

  • cast_forward_inputs (bool) – 如果 True,則此 FSDP 模組會將其 forward args 和 kwargs 轉換為 param_dtype。 這是為了確保參數和輸入 dtypes 與 forward 計算相符,這是許多運算所必需的。 當僅將混合精度應用於部分而非所有 FSDP 模組時,可能需要將此設定為 True,在這種情況下,混合精度 FSDP 子模組需要重新轉換其輸入。(預設值:False

  • cast_root_forward_inputs (bool) – 如果 True,則 root FSDP 模組會將其 forward args 和 kwargs 轉換為 param_dtype,從而覆寫 cast_forward_inputs 的值。 對於非 root FSDP 模組,這不會執行任何操作。(預設值:True

  • _module_classes_to_ignore (Sequence[Type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): 這指定了在使用 auto_wrap_policy 時要忽略混合精度的模組類別:這些類別的模組將分別應用 FSDP,並禁用混合精度(這意味著最終的 FSDP 構造將偏離指定的策略)。 如果未指定 auto_wrap_policy,則這不會執行任何操作。 此 API 是實驗性的,可能會有所變更。(預設值:(_BatchNorm,)

注意

此 API 是實驗性的,可能會有所變更。

注意

只有浮點張量才會轉換為其指定的 dtypes。

注意

summon_full_params 中,參數會強制使用完整精度,但緩衝區則不會。

注意

Layer norm 和 batch norm 會以 float32 累積,即使其輸入採用低精度(例如 float16bfloat16)也是如此。 針對這些 norm 模組停用 FSDP 的混合精度僅表示仿射參數會保留在 float32 中。 但是,這會導致這些 norm 模組出現單獨的 all-gather 和 reduce-scatter,這可能效率低下,因此,如果工作負載允許,使用者應優先將混合精度應用於這些模組。

注意

預設情況下,如果使用者傳遞一個包含任何 _BatchNorm 模組的模型,並指定 auto_wrap_policy,則 batch norm 模組將分別應用 FSDP,並停用混合精度。 請參閱 _module_classes_to_ignore 參數。

注意

MixedPrecision 預設具有 cast_root_forward_inputs=Truecast_forward_inputs=False。 對於 root FSDP 實例,其 cast_root_forward_inputs 優先於其 cast_forward_inputs。 對於非 root FSDP 實例,其 cast_root_forward_inputs 值會被忽略。 預設設定足以應付每個 FSDP 實例都具有相同 MixedPrecision 配置,且僅需要在模型 forward 傳遞開始時將輸入轉換為 param_dtype 的典型情況。

注意

對於具有不同 MixedPrecision 配置的巢狀 FSDP 實例,我們建議設定個別 cast_forward_inputs 值,以在每個實例的 forward 傳遞之前配置是否轉換輸入。 在這種情況下,由於轉換發生在每個 FSDP 實例的 forward 傳遞之前,因此父 FSDP 實例應使其非 FSDP 子模組在其 FSDP 子模組之前運行,以避免由於不同的 MixedPrecision 配置而導致啟動 dtype 發生變更。

範例

>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>>     model[1],
>>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
>>> )
>>> model = FSDP(
>>>     model,
>>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
>>> )

上述範例展示了一個可運作的範例。 另一方面,如果將 model[1] 替換為 model[0],這表示使用不同 MixedPrecision 的子模組首先運行其 forward 傳遞,則 model[1] 將錯誤地看到 float16 啟動,而不是 bfloat16 啟動。

class torch.distributed.fsdp.CPUOffload(offload_params=False)[source][source]

這會配置 CPU 卸載。

變數

offload_params (bool) – 這指定在不涉及計算時是否將參數卸載到 CPU。 如果 True,則這也會將梯度卸載到 CPU,這表示優化器步驟在 CPU 上運行。

class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[source][source]

StateDictConfig 是所有 state_dict 組態類別的基底類別。 使用者應實例化一個子類別(例如 FullStateDictConfig),以便設定 FSDP 支援的對應 state_dict 類型的設定。

變數

offload_to_cpu (bool) – 如果 True,則 FSDP 會將 state dict 的數值卸載 (offload) 到 CPU;如果 False,則 FSDP 會將它們保留在 GPU 上。(預設值:False

class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[source][source]

FullStateDictConfig 是一個配置類別,旨在與 StateDictType.FULL_STATE_DICT 一起使用。 我們建議啟用 offload_to_cpu=Truerank0_only=True,以便在儲存完整的 state dict 時分別節省 GPU 記憶體和 CPU 記憶體。 此配置類別旨在透過 state_dict_type() 上下文管理器使用,如下所示

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>>     state = fsdp.state_dict()
>>>     # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>>     # Load checkpoint only on rank 0 to avoid memory redundancy
>>>     state_dict = torch.load("my_checkpoint.pt")
>>>     model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
變數

rank0_only (bool) – 如果 True,則只有 rank 0 會儲存完整的 state dict,而其他 rank 則儲存空的 dict。 如果 False,則所有 rank 都會儲存完整的 state dict。(預設值:False

class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[source][source]

ShardedStateDictConfig 是一個配置類別,旨在與 StateDictType.SHARDED_STATE_DICT 一起使用。

變數

_use_dtensor (bool) – 如果 True,則 FSDP 會將 state dict 的數值儲存為 DTensor;如果 False,則 FSDP 會將它們儲存為 ShardedTensor。(預設值:False

警告

_use_dtensorShardedStateDictConfig 的一個私有欄位,FSDP 使用它來判斷 state dict 數值的類型。 使用者不應手動修改 _use_dtensor

class torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False)[source][source]
class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[source][source]

OptimStateDictConfig 是所有 optim_state_dict 配置類別的基底類別。 使用者應實例化子類別 (例如 FullOptimStateDictConfig) ,以便為 FSDP 支援的相應 optim_state_dict 類型配置設定。

變數

offload_to_cpu (bool) – 如果 True,則 FSDP 會將 state dict 的張量數值卸載到 CPU;如果 False,則 FSDP 會將它們保留在原始裝置上 (除非啟用了參數 CPU 卸載,否則為 GPU)。(預設值:True

class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[原始碼][原始碼]
變數

rank0_only (bool) – 如果 True,則只有 rank 0 會儲存完整的 state dict,而其他 rank 則儲存空的 dict。 如果 False,則所有 rank 都會儲存完整的 state dict。(預設值:False

class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[原始碼][原始碼]

ShardedOptimStateDictConfig 是一個組態類別,旨在與 StateDictType.SHARDED_STATE_DICT 一起使用。

變數

_use_dtensor (bool) – 如果 True,則 FSDP 會將 state dict 的數值儲存為 DTensor;如果 False,則 FSDP 會將它們儲存為 ShardedTensor。(預設值:False

警告

_use_dtensorShardedOptimStateDictConfig 的私有欄位,FSDP 使用它來決定 state dict 值的類型。使用者不應手動修改 _use_dtensor

class torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False)[原始碼][原始碼]
class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[原始碼][原始碼]

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源