分散式檢查點¶
PyTorch/XLA SPMD 透過專用的 Planner
實例與 torch.distributed.checkpoint 函式庫相容。使用者能夠透過這個通用介面同步儲存和載入檢查點。
SPMDSavePlanner
和 SPMDLoadPlanner
(src) 類別使 save
和 load
函式能夠直接在 XLAShardedTensor
的分片上運作,從而在 SPMD 訓練中實現分散式檢查點的所有優點。
以下是同步分散式檢查點 API 的示範
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc
# Saving a state_dict
state_dict = {
"model": model.state_dict(),
"optim": optim.state_dict(),
}
dist_cp.save(
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
planner=xc.SPMDSavePlanner(),
)
...
# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
"model": model.state_dict(),
}
dist_cp.load(
state_dict=state_dict,
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])
CheckpointManager¶
實驗性的 CheckpointManager 介面在 torch.distributed.checkpoint
函式之上提供了更高等級的 API,以實現以下幾個關鍵功能
受管理的檢查點:由
CheckpointManager
取得的每個檢查點都由取得檢查點的步驟識別。所有追蹤的步驟都可以透過CheckpointManager.all_steps
方法存取,並且可以使用CheckpointManager.restore
還原任何追蹤的步驟。非同步檢查點:透過
CheckpointManager.save_async
API 取得的檢查點會非同步寫入持久儲存空間,以在檢查點持續期間解除封鎖訓練。輸入的分片 state_dict 會先移動到 CPU,然後再將檢查點分派到背景執行緒。發生搶佔時自動檢查點:在 Cloud TPU 上,可以偵測到搶佔,並在程序終止之前取得檢查點。若要使用,請確保您的 TPU 是透過啟用 自動檢查點 的 QueuedResource 佈建的,並確保在建構 CheckpointManager 時設定
chkpt_on_preemption
參數 (預設啟用此選項)。FSSpec 支援:
CheckpointManager
使用 fsspec 儲存後端,以直接將檢查點儲存到任何與 fsspec 相容的檔案系統,包括 GCS。
CheckpointManager 的範例用法如下
from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer
# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)
# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
# Choose the highest step
best_step = max(tracked_steps)
# Before restoring the checkpoint, the optimizer state must be primed
# to allow state to be loaded into it.
prime_optimizer(optim)
state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
chkpt_mgr.restore(best_step, state_dict)
model.load_state_dict(state_dict['model'])
optim.load_state_dict(state_dict['optim'])
# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
...
state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
if chkpt_mgr.save_async(step, state_dict):
print(f'Checkpoint taken at step {step}')
還原 Optimizer 狀態¶
在分散式檢查點中,state_dict 會就地載入,並且僅載入檢查點的必要分片。由於 optimizer 狀態是延遲建立的,因此在第一次呼叫 optimizer.step
之前,狀態不會出現,並且嘗試載入未啟動的 optimizer 將會失敗。
為此提供了實用方法 prime_optimizer
:它透過將所有梯度設定為零並呼叫 optimizer.step
來執行虛擬訓練步驟。這是一種破壞性方法,會影響模型參數和 optimizer 狀態,因此僅應在還原之前呼叫。
處理序群組¶
若要使用 torch.distributed
API (例如分散式檢查點),則需要處理序群組。在 SPMD 模式下,由於編譯器負責所有集合運算,因此不支援 xla
後端。
相反地,必須使用 CPU 處理序群組,例如 gloo
。在 TPU 上,仍然支援 xla://
init_method 來探索主機 IP、全域世界大小和主機排名。以下是一個範例初始化
import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr
xr.use_spmd()
# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')