捷徑

分散式檢查點

PyTorch/XLA SPMD 透過專用的 Planner 實例與 torch.distributed.checkpoint 函式庫相容。使用者能夠透過這個通用介面同步儲存和載入檢查點。

SPMDSavePlannerSPMDLoadPlanner (src) 類別使 saveload 函式能夠直接在 XLAShardedTensor 的分片上運作,從而在 SPMD 訓練中實現分散式檢查點的所有優點。

以下是同步分散式檢查點 API 的示範

import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

# Saving a state_dict
state_dict = {
    "model": model.state_dict(),
    "optim": optim.state_dict(),
}

dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)
...

# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
    "model": model.state_dict(),
}

dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])

CheckpointManager

實驗性的 CheckpointManager 介面在 torch.distributed.checkpoint 函式之上提供了更高等級的 API,以實現以下幾個關鍵功能

  • 受管理的檢查點:由 CheckpointManager 取得的每個檢查點都由取得檢查點的步驟識別。所有追蹤的步驟都可以透過 CheckpointManager.all_steps 方法存取,並且可以使用 CheckpointManager.restore 還原任何追蹤的步驟。

  • 非同步檢查點:透過 CheckpointManager.save_async API 取得的檢查點會非同步寫入持久儲存空間,以在檢查點持續期間解除封鎖訓練。輸入的分片 state_dict 會先移動到 CPU,然後再將檢查點分派到背景執行緒。

  • 發生搶佔時自動檢查點:在 Cloud TPU 上,可以偵測到搶佔,並在程序終止之前取得檢查點。若要使用,請確保您的 TPU 是透過啟用 自動檢查點 的 QueuedResource 佈建的,並確保在建構 CheckpointManager 時設定 chkpt_on_preemption 參數 (預設啟用此選項)。

  • FSSpec 支援CheckpointManager 使用 fsspec 儲存後端,以直接將檢查點儲存到任何與 fsspec 相容的檔案系統,包括 GCS。

CheckpointManager 的範例用法如下

from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer

# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)

# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
    # Choose the highest step
    best_step = max(tracked_steps)
    # Before restoring the checkpoint, the optimizer state must be primed
    # to allow state to be loaded into it.
    prime_optimizer(optim)
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    chkpt_mgr.restore(best_step, state_dict)
    model.load_state_dict(state_dict['model'])
    optim.load_state_dict(state_dict['optim'])

# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
    ...
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if chkpt_mgr.save_async(step, state_dict):
        print(f'Checkpoint taken at step {step}')

還原 Optimizer 狀態

在分散式檢查點中,state_dict 會就地載入,並且僅載入檢查點的必要分片。由於 optimizer 狀態是延遲建立的,因此在第一次呼叫 optimizer.step 之前,狀態不會出現,並且嘗試載入未啟動的 optimizer 將會失敗。

為此提供了實用方法 prime_optimizer:它透過將所有梯度設定為零並呼叫 optimizer.step 來執行虛擬訓練步驟。這是一種破壞性方法,會影響模型參數和 optimizer 狀態,因此僅應在還原之前呼叫。

處理序群組

若要使用 torch.distributed API (例如分散式檢查點),則需要處理序群組。在 SPMD 模式下,由於編譯器負責所有集合運算,因此不支援 xla 後端。

相反地,必須使用 CPU 處理序群組,例如 gloo。在 TPU 上,仍然支援 xla:// init_method 來探索主機 IP、全域世界大小和主機排名。以下是一個範例初始化

import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr

xr.use_spmd()

# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源