分散式檢查點 - torch.distributed.checkpoint¶
分散式檢查點 (DCP) 支援從多個 rank 並行載入和儲存模型。它處理載入時的重新分片,從而能夠在一個叢集拓撲中儲存並載入到另一個叢集拓撲中。
DCP 與 torch.save 和 torch.load 在幾個重要方面有所不同
它為每個檢查點生成多個檔案,每個 rank 至少一個。
它以就地方式運作,這意味著模型應首先分配其數據,然後 DCP 使用該儲存空間。
載入和儲存檢查點的進入點如下
其他資源:¶
- torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[原始碼][原始碼]¶
以 SPMD 樣式儲存分散式模型。
此函數與
torch.save()
不同,因為它透過讓每個 rank 僅儲存其本地分片來處理ShardedTensor
和DTensor
。對於每個
Stateful
物件(同時具有state_dict
和load_state_dict
),save 將在序列化之前呼叫state_dict
。警告
無法保證跨 PyTorch 版本儲存的 state_dict 的向後相容性。
警告
如果使用 process_group 參數,請確保只有其 ranks 呼叫 save_state_dict 且 state_dict 中的所有資料都屬於它。
注意
當為 FSDP 的 ShardingStrategy.HYBRID_SHARD 儲存檢查點時,只有其中一個 shard_group 應該呼叫 save_state_dict,並且需要傳入對應的 process group。
注意
- 如果沒有可用的 process group,此函數會假設目的是儲存
本地進程中的 state_dict。
- 參數
state_dict (Dict[str, Any]) – 要儲存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此檢查點實例的 ID。 checkpoint_id 的含義取決於儲存體。它可以是資料夾或檔案的路徑。如果儲存體是鍵值儲存,它也可以是鍵。(預設值:
None
)storage_writer (Optional[StorageWriter]) – 用於執行寫入的 StorageWriter 實例。如果未指定,DCP 將根據 checkpoint_id 自動推斷 writer。如果 checkpoint_id 也為 None,將引發例外。(預設值:
None
)planner (Optional[SavePlanner]) – SavePlanner 實例。如果未指定,將使用預設的 planner。(預設值:
None
)process_group (Optional[ProcessGroup]) – 用於跨 rank 同步的 ProcessGroup。(預設值:
None
)
- 傳回
已儲存檢查點的中繼資料物件。
- 傳回類型
Metadata
範例
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") >>> torch.distributed.checkpoint.save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, >>> )
注意
save_state_dict 使用 collectives 來協調跨 ranks 的寫入。對於基於 NCCL 的 process groups,物件的內部張量表示必須在通訊發生之前移動到 GPU 裝置。在這種情況下,使用的裝置由
torch.cuda.current_device()
给出,使用者有責任確保設定此裝置,以便每個 rank 都有一個單獨的 GPU,透過torch.cuda.set_device()
。
- torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[原始碼][原始碼]¶
save
的非同步版本。此程式碼首先將 state_dict 從記憶體中取出到暫存儲存體(預設為 CPU 記憶體),然後在單獨的執行緒中呼叫 save。警告
此功能為實驗性功能,可能會發生變更。
- 參數
state_dict (Dict[str, Any]) – 要儲存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此檢查點實例的 ID。 checkpoint_id 的含義取決於儲存體。它可以是資料夾或檔案的路徑。如果儲存體是鍵值儲存,它也可以是鍵。(預設值:
None
)storage_writer (Optional[StorageWriter]) – 用於執行 'stage' 和 'save' 的 StorageWriter 實例。如果未指定,DCP 將根據 checkpoint_id 自動推斷 writer。如果 checkpoint_id 也為 None,將引發例外。(預設值:
None
)planner (Optional[SavePlanner]) – SavePlanner 實例。如果未指定,將使用預設的 planner。(預設值:
None
)process_group (Optional[ProcessGroup]) – 用於跨 rank 同步的 ProcessGroup。(預設值:
None
)
- 傳回
包含來自 save 的結果 Metadata 物件的 future。
- 傳回類型
範例
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") >>> checkpoint_future = torch.distributed.checkpoint.async_save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, >>> ) >>> >>> # ... do some work ... >>> >>> checkpoint_future.result()
- torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source][source]¶
這個方法已過時。請改用 'save'。
- 傳回類型
Metadata
- torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None)[source][source]¶
以 SPMD 樣式載入分散式
state_dict
。每個 rank 都會嘗試讀取滿足請求的 state_dict 所需的最少量資料。當載入
ShardedTensor
或DTensor
實例時,每個 rank 只會讀取其本地 shard 的資料。對於每個
Stateful
物件 (同時具有state_dict
和load_state_dict
),load 會先呼叫state_dict
,然後嘗試還原序列化,接著在還原序列化完成後呼叫load_state_dict
。對於每個非Stateful
物件,load 將會還原序列化物件,然後在state_dict
中使用還原序列化的物件替換它。警告
state_dict
中的所有 tensor 必須在呼叫此函式之前分配到其目標裝置上。所有非 tensor 資料都使用 torch.load() 載入,並在 state_dict 上就地修改。
警告
使用者必須在根模組上呼叫 load_state_dict,以確保載入後處理和非 tensor 資料正確傳播。
- 參數
state_dict (Dict[str, Any]) – 要儲存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此檢查點實例的 ID。 checkpoint_id 的含義取決於儲存體。它可以是資料夾或檔案的路徑。如果儲存體是鍵值儲存,它也可以是鍵。(預設值:
None
)storage_reader (Optional[StorageReader]) – 用於執行讀取的 StorageWriter 實例。如果未指定,DCP 將根據 checkpoint_id 自動推斷讀取器。如果 checkpoint_id 也為 None,則會引發例外。(預設:
None
)planner (Optional[LoadPlanner]) – LoadPlanner 實例。如果未指定,將使用預設的 planner。(預設:
None
)process_group (Optional[ProcessGroup]) – 用於跨 rank 同步的 ProcessGroup。(預設值:
None
)
- 傳回
None.
- 傳回類型
None
- 範例
>>> my_model = MyModule() >>> optimizer = Adagrad(my_model.parameters()) >>> model_state_dict = my_model.state_dict() >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict( >>> state_dict=model_state_dict, >>> storage_reader=fs_storage_reader, >>> )
>>> # module.load_state_dict() function might have customized steps >>> # to flush the state_dict, must call it to >>> # ensure correct behavior. >>> my_model.load_state_dict(model_state_dict)
注意
load_state_dict 使用集合來協調跨 rank 的讀取。對於基於 NCCL 的 process group,物件的內部 tensor 表示必須先移動到 GPU 裝置,然後才能進行通訊。在這種情況下,使用的裝置由
torch.cuda.current_device()
給出,使用者有責任確保已設定此值,以便每個 rank 都擁有一個單獨的 GPU,透過torch.cuda.set_device()
。
- torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source][source]¶
這個方法已過時。請改用 'load'。
以下模組也可用於額外客製化用於非同步 checkpointing 的分期機制 (torch.distributed.checkpoint.async_save)
- class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)[source][source]¶
此協定旨在為 dcp.async_save 提供客製化和擴展性,允許使用者自定義在平行執行通常的 dcp.save 路徑之前,如何暫存資料。預期的操作順序(具體定義於 torch.distributed.state_dict_saver.async_save)如下:
- AsyncStager.stage_data(state_dict)
此呼叫讓 AsyncStager 有機會「暫存」state_dict。在此上下文中,暫存的期望和目的是創建 state dict 的「訓練安全」表示形式,這意味著在暫存完成後對模組資料的任何更新都不應反映在此方法返回的 state dict 中。例如,在預設情況下,會在 CPU RAM 上建立整個 state dict 的副本並在此處返回,允許使用者繼續訓練,而不會冒著資料變更的風險,因為資料正在序列化。
- 在平行狀態下,對從 stage 返回的 state_dict 呼叫 dcp.save。此呼叫負責
序列化 state_dict 並將其寫入儲存空間。
- 如果 AsyncStager.should_synchronize_after_execute 為 True,則此方法將在
序列化執行緒啟動後立即呼叫,並在從 dcp.async_save 返回之前呼叫。如果此值設定為 False,則假設使用者已定義一個自定義同步點,以進一步優化訓練迴圈中的儲存延遲(例如,透過將暫存與正向/反向傳遞重疊),並且使用者有責任在適當的時間呼叫 AsyncStager.synchronize_staging。
- class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)[source][source]¶
AsyncStager 的一種實現,可在 CPU RAM 上暫存 state_dict 並阻止,直到複製完成。此實現還提供了一個選項,可以使用釘選記憶體來優化 stage 延遲。
注意:在這種情況下,synchronize_staging 是一個空操作。
除了上述進入點之外,如下所述的 Stateful 物件還在儲存/載入期間提供額外的自定義 .. automodule:: torch.distributed.checkpoint.stateful
- class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[source][source]¶
用於可檢查點和還原物件的 Statefull 協議。
此範例示範如何使用 Pytorch Distributed Checkpoint 來儲存 FSDP 模型。
以下類型定義檢查點期間使用的 IO 介面
- class torch.distributed.checkpoint.StorageReader[source][source]¶
load_state_dict
用來從儲存空間讀取的介面。一個 StorageReader 實例同時充當分散式檢查點中的協調器和追隨者。作為初始化的一部分,每個實例都會被告知其角色。
子類別應預期
load_state_dict
進行以下呼叫序列(所有 rank) 如果使用者傳遞有效的 checkpoint_id,則設定 checkpoint_id。
(所有 rank) read_metadata()
(所有 rank) set_up_storage_reader()
(所有 rank) prepare_local_plan()
(協調器) prepare_global_plan()
(所有 rank) read_data()
- abstract prepare_global_plan(plans)[source][source]¶
執行儲存空間載入的集中式規劃。
此方法僅在協調器實例上呼叫。
雖然此方法可以產生完全不同的計畫,但首選方法是在 LoadPlan::storage_data 中儲存儲存空間特定資料。
- abstract prepare_local_plan(plan)[source][source]¶
執行儲存空間特定的本地規劃。
雖然此方法可以產生完全不同的計畫,但建議的方法是在 LoadPlan::storage_data 中儲存儲存空間特定資料。
- abstract read_data(plan, planner)[source][source]¶
使用
planner
解析資料,從plan
讀取所有項目。子類別應呼叫
LoadPlanner::load_bytes
以將 BytesIO 物件反序列化到正確的位置。子類別應呼叫
LoadPlanner::resolve_tensor
以存取它應將資料載入到的張量。StorageLayer 負責正確排程所需的任何跨裝置複製。
- 參數
plan (LoadPlan) – 要執行的本機計畫
planner (LoadPlanner) – 用於解析項目的 planner 物件。
- 傳回
一個 future,在所有讀取完成後完成。
- 傳回類型
Future[None]
- abstract reset(checkpoint_id=None)[source][source]¶
表示即將進行全新的檢查點讀取。如果使用者為此檢查點讀取設定了 checkpoint_id,則可能會出現 checkpoint_id。 checkpiont_id 的含義取決於儲存。它可以是資料夾/檔案的路徑,也可以是鍵值儲存的鍵。
- 參數
checkpoint_id (Union[str, os.PathLike, None]) – 此檢查點實例的 ID。 checkpoint_id 的含義取決於儲存。它可以是資料夾或檔案的路徑。如果儲存更像是鍵值儲存,它也可以是一個鍵。(預設值:
None
)
- class torch.distributed.checkpoint.StorageWriter[source][source]¶
save_state_dict
用於寫入儲存的介面。一個 StorageWriter 實例在分散式檢查點中同時充當協調器和追隨者。 作為初始化的一部分,每個實例都會被告知其角色。
一個子類別應該預期以下呼叫順序。
(所有 rank) 如果使用者傳遞有效的 checkpoint_id,則設定 checkpoint_id。
(所有節點) set_up_storage_writer()
(所有 rank) prepare_local_plan()
(協調器) prepare_global_plan()
(所有節點) write_data()
(協調器) finish()
- abstract finish(metadata, results)[source][source]¶
寫入元數據並將當前檢查點標記為成功。
用於序列化 metadata 的實際格式/架構是一個實作細節。 唯一的要求是它可以恢復到相同的物件圖。
- abstract prepare_global_plan(plans)[source][source]¶
執行儲存的集中規劃。
此方法僅在協調器實例上呼叫。
雖然此方法可以產生完全不同的計畫,但首選方法是在 SavePlan::storage_data 中儲存儲存特定資料。
- abstract prepare_local_plan(plan)[source][source]¶
執行儲存空間特定的本地規劃。
雖然此方法可以產生完全不同的計畫,但建議的方式是將儲存空間特定的資料儲存在 SavePlan::storage_data 中。
- abstract reset(checkpoint_id=None)[source][source]¶
調用表示即將進行全新的檢查點寫入。如果使用者為此檢查點寫入設定了 checkpoint_id,則可能存在 checkpoint_id。checkpiont_id 的含義取決於儲存空間。它可以是資料夾/檔案的路徑,或是鍵值儲存空間的鍵。
- 參數
checkpoint_id (Union[str, os.PathLike, None]) – 此檢查點實例的 ID。 checkpoint_id 的含義取決於儲存體。它可以是資料夾或檔案的路徑。如果儲存體是鍵值儲存,它也可以是鍵。(預設值:
None
)
- abstract set_up_storage_writer(is_coordinator)[source][source]¶
初始化此實例。
- 參數
is_coordinator (bool) – 此實例是否負責協調檢查點。
- storage_meta()[source][source]¶
傳回儲存空間特定的元資料。這用於在檢查點中儲存額外資訊,這對於提供請求層級的可觀察性很有用。 StorageMeta 在儲存調用期間傳遞給
SavePlanner
。預設傳回 None。待辦事項:提供範例
- 傳回類型
Optional[StorageMeta]
- abstract classmethod validate_checkpoint_id(checkpoint_id)[source][source]¶
檢查給定的 checkpoint_id 是否受儲存支援。 這讓我們能夠啟用自動儲存選擇。
- 傳回類型
- abstract write_data(plan, planner)[source][source]¶
使用
planner
解析資料,從plan
寫入所有項目。子類別應調用
SavePlanner::resolve_data
以取得計畫中每個項目的底層物件的存取權,以便進行寫入。子類別應延遲調用 resolve_data,因為它可能會分配記憶體。對於張量,請做出以下假設
它們可能位於任何裝置上,包括與
WriteItem::tensor_data
上的裝置不符的裝置它們可能是或可能不是檢視或不連續的。只需要儲存投影。
- 參數
plan (SavePlan) – 要執行的儲存計畫。
planner (SavePlanner) – 用於將項目解析為資料的規劃器物件。
- 傳回
完成 WriteResult 清單的 future
- 傳回類型
以下類型定義了檢查點期間使用的規劃器介面
- class torch.distributed.checkpoint.LoadPlanner[source][source]¶
定義 load_state_dict 用於規劃載入過程的協定的抽象類別。
LoadPlanner 是有狀態的物件,可用於自訂整個載入過程。
LoadPlanner 充當 state_dict 的存取 Proxy,因此對其進行的任何轉換對整個過程都是可見的。
規劃器子類別可以預期在 load_state_dict 期間進行以下調用順序
- set_up_planner - 在所有 rank 上調用。
表示開始載入檢查點。
- create_local_plan - 在所有 rank 上調用。
處理 state_dict 並產生將被傳送以進行全域規劃的 LoadPlan。
- create_global_plan - 僅在協調器 rank 上調用。
採用來自所有 rank 的 LoadPlan 並做出任何全域決策。
- load_bytes - 在每個 rank 上多次調用
這會針對 state_dict 中的每個非張量值調用一次。
- resolve_tensor 和 commit_tensor - 在每個 rank 上多次調用
它們針對 state_dict 中的每個 Tensor 值以成對方式調用。
建議使用者擴展 DefaultLoadPlanner 而不是直接擴展此介面,因為大多數變更都可以透過單個方法中的變更來表達。
有兩種常見的擴展模式
重寫 state_dict。這是擴展載入過程的最簡單方法,因為它不需要了解 LoadPlan 如何運作的複雜性。我們需要保留對原始 state_dict 的參考,因為載入是就地發生的,因此我們需要能夠就地執行它
>>> class RenamePlanner(DefaultLoadPlanner): >>> def set_up_planner( >>> self, >>> state_dict: STATE_DICT_TYPE, >>> metadata: Metadata, >>> is_coordinator: bool, >>> ) -> None: >>> self.original_state_dict = state_dict >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} >>> >>> if self.flatten_sharded_tensors: >>> state_dict = _flatten_sharded_tensors(state_dict) >>> >>> if self.flatten_state_dict: >>> state_dict, self.mappings = flatten_state_dict(state_dict) >>> >>> self.state_dict = state_dict >>> self.metadata = metadata >>> self.is_coordinator = is_coordinator >>> >>> def load_bytes(self, read_item, value): >>> # Remove the "foo_" prefix >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)
修改 resolve_tensor 和 commit_tensor 以處理載入時間轉換。
>>> class MetaModelMaterialize(DefaultSavePlanner): >>> def resolve_tensor(self, read_item): >>> tensor = super().resolve_tensor(read_item) >>> return torch.empty_like(tensor, device="cpu") >>> >>> def commit_tensor(self, read_item, tensor): >>> self.state_dict[read_item.dest_index.fqn] = tensor
- abstract commit_tensor(read_item, tensor)[source][source]¶
當 StorageReader 完成將資料載入到
tensor
後呼叫。提供的 tensor 與呼叫
resolve_tensor
返回的 tensor 相同。 只有當此 LoadPlanner 需要在將tensor
複製回 state_dict 中的 tensor 之前對其進行後處理時,才需要此方法。tensor 的內容將遵循其裝置同步模型。
- abstract create_global_plan(global_plan)[source][source]¶
計算全域載入計畫並傳回每個 rank 的計畫。
注意:這僅在協調器 rank 上呼叫
- abstract create_local_plan()[source][source]¶
根據 set_up_planner 提供的 state_dict 和 metadata 建立 LoadPlan。
注意:這在每個 rank 上呼叫。
- 傳回類型
- abstract load_bytes(read_item, value)[source][source]¶
載入由
read_item``和 ``value
描述的項目。預期此方法會就地修改底層的 state_dict。
value
的內容由用於產生正在載入的檢查點的 SavePlanner 定義。
- resolve_bytes(read_item)[source][source]¶
傳回 StorageReader 用於載入 read_item 的 BytesIO。
BytesIO 應與底層 state_dict 上的 BytesIO 建立別名,因為 StorageReader 將替換其內容。
- 傳回類型
BytesIO
- class torch.distributed.checkpoint.LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data<
- class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source][source]¶
- class torch.distributed.checkpoint.SavePlanner[source][source]¶
Abstract class defining the protocol used by save_state_dict to plan the save process.
SavePlanners are stateful objects that can be used to customize the whole save process.
SavePlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.
A planner subclass can expect the following sequence of calls during save_state_dict
- set_up_planner - 在所有 rank 上調用。
Signals the start of a checkpoint save.
- create_local_plan - 在所有 rank 上調用。
Process the state_dict and produces a SavePlan that will be sent for global planning.
- create_global_plan - 僅在協調器 rank 上調用。
Takes the SavePlan from all ranks and make any global decision.
- finish_plan - called on all ranks.
This gives each rank a chance to adjust to global planning decisions.
- resolve_data - called multiple times on each rank
Lookups a value on the state_dict for the storage layer to write.
Users are recommended to extend DefaultSavePlanner instead of this interface directly as most changes can be expressed by changes in a single method.
There are 3 usual patterns of extension
Rewriting state_dict. This is the simplest way to extend the save process as it doesn’t requite understanding the intrincacies of how SavePlan works
>>> class RenamePlanner(DefaultSavePlanner): >>> def set_up_planner( >>> self, >>> state_dict: STATE_DICT_TYPE, >>> storage_meta: Optional[StorageMeta], >>> is_coordinator: bool, >>> ) -> None: >>> # prefix all keys with `foo_`` >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)
Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted
>>> class FP16Planner(DefaultSavePlanner): >>> def create_local_plan(self): >>> plan = super().create_local_plan() >>> for p in plan: >>> if p.tensor_data is not None: >>> p.tensor_data.properties.dtype = torch.float16 >>> return plan >>> >>> def resolve_data(self, write_item): >>> item = super().resolve_data(write_item) >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
Using the global planning step to make central decisions that can’t be made individually by each rank
>>> from itertools import zip_longest >>> from dataclasses import replace >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 >>> # This sample doesn't handle ShardedTensors >>> def create_global_plan(self, all_plans): >>> iters = [iter(all_plans[0].items)] * len(all_plans) >>> items_per_rank = [ >>> [item for item in items if item is not None] >>> for items in zip(*zip_longest(*iters), strict=True) >>> ] >>> all_plans = [ >>> replace(plan, items=items) >>> for plan, items in zip(all_plans, items_per_rank, strict=True) >>> ] >>> return super().create_global_plan(all_plans)
Finally, some planners need to save additional metadata in the checkpoint, this is accomplished by having each rank contribute their data items in the local plan and the global planner aggregate them
>>> class SaveExtraDataPlanner(DefaultSavePlanner): >>> def create_local_plan(self) -> SavePlan: >>> plan = super().create_local_plan() >>> return replace(plan, planner_data="per-rank-data") >>> >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: >>> global_plan, metadata = super().create_global_plan(all_plans) >>> merged_data = [p.planner_data for p in global_plan] >>> metadata = replace(metadata, planner_data=merged_data) >>> return global_plan, metadata
- abstract create_global_plan(all_plans)[source][source]¶
Compute the global checkpoint plan and return the local plan of each rank.
This is called on the coordinator rank only.
- abstract create_local_plan()[source][source]¶
Compute the save plan for the current rank.
This will be aggregated and passed to create_global_plan. Planner specific data can be passed through SavePlan::planner_data.
This is called on all ranks.
- 傳回類型
- abstract finish_plan(new_plan)[source][source]¶
合併由 create_local_plan 建立的計畫,以及 create_global_plan 的結果。
This is called on all ranks.
- 傳回類型
- abstract resolve_data(write_item)[source][source]¶
轉換並準備來自
state_dict
的write_item
以進行儲存,確保等冪性和線程安全性。在
state_dict
中查找與write_item
關聯的物件,並在儲存層使用它之前應用任何轉換(例如序列化)。在每個 rank 上多次調用,最終 SavePlan 中每個 WriteItem 至少調用一次。
此方法應具有等冪性和線程安全性。StorageWriter 實作可以根據需要經常調用它。
任何分配記憶體的轉換都應該在此方法被調用時延遲執行,以減少 checkpointing 所需的峰值記憶體。
當返回 tensors 時,它們可以在任何設備或格式上,它們也可以是 views。儲存層有責任弄清楚如何儲存它們。
- class torch.distributed.checkpoint.SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None)[source][source]¶
- class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)[source][source]¶
資料類別,用於保存有關需要寫入儲存體的資訊。
我們提供基於檔案系統的儲存層
- class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True)[source][source]¶
使用檔案 I/O 的 StorageWriter 基本實作。
此實作進行以下假設和簡化:
檢查點路徑是一個空的或不存在的目錄。
檔案建立是原子性的。
檢查點由每個寫入請求一個檔案,加上一個包含序列化元資料的 .metadata 檔案組成。
我們提供 LoadPlanner 和 SavePlanner 的預設實作,可以處理所有 torch.distributed 的結構,例如 FSDP、DDP、ShardedTensor 和 DistributedTensor。
- class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False)[source][source]¶
- class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source][source]¶
DefaultLoadPlanner 在 LoadPlanner 之上新增多種功能。
特別是,它新增以下功能:
flatten_state_dict:處理具有巢狀字典的 state_dict。flatten_sharded_tensors:針對 2D 平行模式中的 FSDP。allow_partial_load:如果為 False,如果 state_dict 中存在一個 key,但在檢查點中不存在,將引發執行階段錯誤。
由於遺留的設計決策,即使原始的未平行化模型相同,FSDP 和 DDP 的狀態字典可能具有不同的鍵或完整名稱 (例如,layer1.weight)。此外,FSDP 提供各種型別的模型狀態字典,例如完整和分片狀態字典。此外,最佳化器狀態字典使用參數 ID 而不是完整名稱來識別參數,這可能會在使用平行化時 (例如,管線平行化) 導致問題。
為了應對這些挑戰,我們提供了一系列的 API,讓使用者能夠輕鬆管理 state_dicts。 get_model_state_dict 會回傳一個模型 state dictionary,其鍵值與未平行化模型 state dictionary 回傳的鍵值一致。 類似地,get_optimizer_state_dict 會提供最佳化器 state dictionary,其鍵值在所有平行化應用中都是一致的。 為了實現這種一致性,get_optimizer_state_dict 會將參數 ID 轉換為完全合格的名稱 (fully qualified names),這些名稱與在未平行化模型 state dictionary 中找到的名稱相同。
請注意,這些 API 回傳的結果可以直接與 torch.distributed.checkpoint.save() 和 torch.distributed.checkpoint.load() 方法一起使用,而無需任何額外的轉換。
請注意,此功能為實驗性功能,API 簽章未來可能會變更。
- torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source][source]¶
回傳模型 state_dict 和最佳化器 state_dict。
get_state_dict
可以處理任何由 PyTorch FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 以及這些平行化的任何組合所平行化的模組。get_state_dict
的主要功能有:1.) 回傳一個模型和最佳化器 state_dict,可以用不同數量的訓練器和/或不同的平行化方式重新分片。 2.) 隱藏特定於平行化的 state_dict API。 使用者不必呼叫這些 API。 3.) 健全性檢查結果 state_dict。結果 state dictionary 的鍵是標準 FQN (完全合格名稱)。 標準 FQN 是指基於參數在 nn.Module 階層中的位置的 FQN。 更具體地說,參數的標準 FQN 是當模組未被任何平行化方式分散時,
module.named_parameters()
或module.named_buffers()
回傳的 FQN。 由於最佳化器在內部使用參數 ID 來表示參數,因此在呼叫此 API 時,會將參數 ID 轉換為標準 FQN。get_state_dict
也可以處理未平行化的模組。 在這種情況下,get_state_dict
僅執行一個功能 - 將最佳化器參數 ID 轉換為標準 FQN。範例
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model)) >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) >>> ddp_model = DDP(copy.deepcopy(model)) >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >>> # the asserts will fail. >>> assert ddp_state_dict == fsdp_state_dict >>> assert ddp_optim_state == fsdp_optim_state_dict
- 參數
model (nn.Module) – 要對其進行建模的 nn.Module。
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – 用於最佳化
model
的最佳化器。submodules (deprecated) – Optional[Set[nn.Module]]: 僅回傳屬於子模組的模型參數。
options (StateDictOptions) – 控制應如何回傳模型 state_dict 和最佳化器 state_dict 的選項。 有關詳細資訊,請參閱 StateDictOptions。
- 傳回
Tuple
包含模型 state_dict 和最佳化器 state_dict。- 傳回類型
- torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source][source]¶
回傳
model
的模型 state_dict。有關詳細使用資訊,請參閱
get_state_dict
。- 參數
model (nn.Module) – 要對其進行建模的 nn.Module。
submodules (deprecated) – Optional[Set[nn.Module]]: 僅回傳屬於子模組的模型參數。
options (StateDictOptions) – 控制應如何回傳模型 state_dict 和最佳化器 state_dict 的選項。 有關詳細資訊,請參閱 StateDictOptions。
- 傳回
model
的 state_dict。- 傳回類型
- torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source][source]¶
回傳最佳化器的合併 state_dict。
有關詳細使用資訊,請參閱
get_state_dict
。- 參數
model (nn.Module) – 要對其進行建模的 nn.Module。
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – 用於最佳化
model
的最佳化器。submodules (deprecated) – Optional[Set[nn.Module]]: 僅回傳屬於子模組的模型參數。
options (StateDictOptions) – 控制應如何回傳模型 state_dict 和最佳化器 state_dict 的選項。 有關詳細資訊,請參閱 StateDictOptions。
- 傳回
optimizers
的 state_dict。- 傳回類型
OptimizerStateType
- torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source][source]¶
載入模型 state_dict 以及優化器 state_dict。
此為
get_state_dict
的對應函數,用於將 state_dict 設定至模型和優化器。給定的model_state_dict
和optim_state_dict
不一定要由get_state_dict
傳回,但必須符合以下要求:1) 所有 FQN 都是如get_state_dict
中定義的正規 FQN,2) 如果 tensor 被分片,則它必須是 ShardedTensor 或 DTensor,3) 優化器 state_dict 不能包含參數 ID;鍵應為正規 FQN。- 參數
model (nn.Module) – 要對其進行建模的 nn.Module。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用於優化
model
的優化器。model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 要載入的模型 state_dict。如果
model_state_dict
的鍵為 nn.Module,則該鍵是model
的子模組,並且值應為該子模組的 state_dict。載入 state_dict 時,子模組的前綴將附加到 state_dict。optim_state_dict (OptimizerStateType) – OptimizerStateType: 要載入的優化器 state_dict。
options (StateDictOptions) – 控制如何載入模型 state_dict 和優化器 state_dict 的選項。詳情請參閱 StateDictOptions。
- 傳回
missing_keys 是一個字串列表,包含模型 state_dict 中缺少的鍵。
unexpected_keys 是一個字串列表,包含模型 state_dict 中未預期的鍵。
- 傳回類型
NamedTuple
具有missing_keys
和unexpected_keys
欄位
- torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source][source]¶
載入模型 state_dict。
此為
get_model_state_dict
的對應函數,用於將 state_dict 設定至模型。有關詳細用法,請參閱set_state_dict
。- 參數
model (nn.Module) – 要對其進行建模的 nn.Module。
model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): 要載入的模型 state_dict。如果
model_state_dict
的鍵為 nn.Module,則該鍵是model
的子模組,並且值應為該子模組的 state_dict。載入 state_dict 時,子模組的前綴將附加到 state_dict。options (StateDictOptions) – 控制如何載入模型 state_dict 和優化器 state_dict 的選項。詳情請參閱 StateDictOptions。
- 傳回
missing_keys 是一個字串列表,包含缺少的鍵
unexpected_keys 是一個字串列表,包含未預期的鍵
- 傳回類型
NamedTuple
具有missing_keys
和unexpected_keys
欄位
- torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)[source][source]¶
載入優化器 state_dict。
此為
get_optimizer_state_dict
的對應函數,用於將 state_dict 設定至優化器。有關詳細用法,請參閱set_state_dict
。- 參數
model (nn.Module) – 要對其進行建模的 nn.Module。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用於優化
model
的優化器。optim_state_dict (OptimizerStateType) – OptimizerStateType: 要載入的優化器 state_dict。
options (StateDictOptions) – 控制如何載入模型 state_dict 和優化器 state_dict 的選項。詳情請參閱 StateDictOptions。
- 傳回
None
- 傳回類型
None
- class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False)[source][source]¶
此資料類別 (dataclass) 指定 get_state_dict/set_state_dict 的運作方式。
full_state_dict
: 如果設定為 True,則會收集傳回的 state_dict 中的所有 tensors。 傳回的 state_dict 中不會有 ShardedTensor 和 DTensor。cpu_offload
: 將所有 tensors 卸載到 CPU。 為了防止 CPU 記憶體不足 (OOM),如果full_state_dict
也為 true,則只有 rank0 會取得 state_dict,而所有其他 rank 將會取得空的 state_dict。ignore_frozen_params
: 如果值為 True,則傳回的 state_dict 不會包含任何凍結的參數 –requires_grad
為 False。 預設值為 False。keep_submodule_prefixes
(已棄用): 當submodules
不是 None 時,此選項表示是否保留 state_dict 鍵中的 submodule 前綴。 例如,如果 submodule 是module.pretrain
並且參數的完整 FQN 是pretrain.layer1.weight
。 當此選項為 True 時,傳回的 state_dict 中參數的鍵將為pretrain.layer1.weight
。 如果選項為 False,則鍵將為layer1.weight
。 請注意,如果keep_submodule_prefixes
為 False,則可能會發生衝突的 FQN,因此submodules
中應該只有一個 submodule。strict
: 當set_state_dict
呼叫 model.load_state_dict() 時的strict
選項。broadcast_from_rank0
: 當選項為 True 時,rank0 應該接收一個完整的 state_dict,並且會將 state_dict/optim_state_dict 中的 tensors 一個一個廣播到其他 rank。 其他 rank 將接收 tensors 並根據模型和最佳化器中的本地 shards 進行分片。 使用此選項時,必須將
full_state_dict
設定為 True。 此選項目前僅支援 DTensor,不支援舊的 ShardedTensor。
對於習慣於使用和共享 torch.save 格式模型的用戶,以下方法提供了離線實用程式,用於在格式之間進行轉換。
- torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[source][source]¶
給定一個包含 DCP 檢查點的目錄,此函數會將其轉換為 Torch save 檔案。
- 參數
警告
為了避免 OOM,建議只在單一 rank 上執行此函數。
- torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[source][source]¶
給定 torch save 檔案的位置,將其轉換為 DCP 檢查點。
- 參數
警告
為了避免 OOM,建議只在單一 rank 上執行此函數。
以下類別也可用於從 torch.save 格式線上載入和重新分片模型。
- class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[source][source]¶
用於讀取 Torch Save 檔案的 StorageReader。此讀取器將在協調器 (coordinator) 排序 (rank) 上讀取整個檢查點 (checkpoint),然後將每個張量 (tensor) 廣播 (broadcast) 並分片 (shard) 到所有排序。
。注意:旨在與 DynamicMetaLoadPlanner 一起使用
警告
目前實作僅支援載入張量 (Tensors)。
>>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> )
- class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source][source]¶
DefaultLoadPlanner 的擴展,它根據傳入的狀態字典 (state dict) 建立一個新的 Metadata 物件,避免了從磁碟讀取元資料的需求。這在讀取沒有元資料檔案的格式 (如 Torch Save 檔案) 時非常有用。
。注意:旨在與 BroadcastingTorchSaveReader 一起使用
警告
目前實作僅支援載入張量 (Tensors)。
>>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> )
提供以下實驗性介面,以改善生產環境中的可觀察性