快捷方式

分散式優化器

警告

使用 CUDA tensors 時,目前不支援分散式優化器

torch.distributed.optim 公開 DistributedOptimizer,它接受遠端參數清單 (RRef) 並在參數所在的 workers 上本地執行優化器。分散式優化器可以使用任何本地優化器 基底類別 來將梯度應用於每個 worker。

class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)[source][source]

DistributedOptimizer 接受分散在各個 worker 上的參數的遠端參考 (remote references),並對每個參數在本地應用指定的 optimizer。

此類別使用 get_gradients() 來檢索特定參數的梯度。

同時調用 step() (無論來自相同或不同的客戶端) 將會在每個 worker 上序列化,因為每個 worker 的 optimizer 一次只能處理一組梯度。但是,無法保證完整的 forward-backward-optimizer 序列會一次執行於一個客戶端。這表示正在應用的梯度可能與給定 worker 上執行的最新 forward pass 不符。此外,也無法保證各個 worker 之間的順序。

DistributedOptimizer 預設情況下會啟用 TorchScript 來建立本地 optimizer,以便在多執行緒訓練 (例如分散式模型並行) 的情況下,optimizer 的更新不會受到 Python 全域直譯器鎖定 (GIL) 的阻礙。目前大多數的 optimizer 都有啟用此功能。您也可以依照 PyTorch 教學中的 秘訣 來為您自己的自訂 optimizer 啟用 TorchScript 支援。

參數
  • optimizer_class (optim.Optimizer) – 在每個 worker 上實例化的 optimizer 類別。

  • params_rref (list[RRef]) – 要優化的本地或遠端參數的 RRef 清單。

  • args – 傳遞給每個 worker 上 optimizer 建構子的引數。

  • kwargs – 傳遞給每個 worker 上 optimizer 建構子的引數。

範例:
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>>   # Forward pass.
>>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>   loss = rref1.to_here() + rref2.to_here()
>>>
>>>   # Backward pass.
>>>   dist_autograd.backward(context_id, [loss.sum()])
>>>
>>>   # Optimizer.
>>>   dist_optim = DistributedOptimizer(
>>>      optim.SGD,
>>>      [rref1, rref2],
>>>      lr=0.05,
>>>   )
>>>   dist_optim.step(context_id)
step(context_id)[source][source]

執行單一最佳化步驟。

這會在包含要優化的參數的每個 worker 上呼叫 torch.optim.Optimizer.step(),並且會阻塞直到所有 worker 都傳回。提供的 context_id 將用於檢索包含應套用於參數的梯度的對應 context

參數

context_id – 應該執行 optimizer 步驟的 autograd context id。

class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)[source][source]

封裝任意的 torch.optim.Optimizer 並執行 post-local SGD。此 optimizer 在每個步驟執行本地 optimizer。在 warm-up 階段之後,它會在應用本地 optimizer 之後定期平均參數。

參數
  • optim (Optimizer) – 本地 optimizer。

  • averager (ModelAverager) – 用於執行 post-localSGD 演算法的模型平均器實例。

範例

>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>>   PostLocalSGDState,
>>>   post_localSGD_hook,
>>> )
>>>
>>> model = nn.parallel.DistributedDataParallel(
>>>    module, device_ids=[rank], output_device=rank
>>> )
>>>
>>> # Register a post-localSGD communication hook.
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
>>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>>     optim=local_optim,
>>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>>
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
>>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
>>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
>>> for step in range(0, 200):
>>>    opt.zero_grad()
>>>    loss = loss_fn(output, labels)
>>>    loss.backward()
>>>    opt.step()
load_state_dict(state_dict)[source][source]

這與 torch.optim.Optimizerload_state_dict() 相同,但也會將模型平均器的 step 值還原為 state_dict 中儲存的值。

如果 state_dict 中沒有 "step" 條目,則會發出警告並將模型平均器的 step 初始化為 0。

state_dict()[source][source]

這與 torch.optim.Optimizerstate_dict() 相同,但增加了一個額外的條目來記錄模型平均器的步數到檢查點,以確保重新載入不會再次導致不必要的熱身。

step()[原始碼][原始碼]

執行單一優化步驟 (參數更新)。

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[原始碼][原始碼]

封裝一個任意的 optim.Optimizer 並在群組中的各個 rank 間對其狀態進行分片。

此共享依照 ZeRO 的描述進行。

每個 rank 中的本地優化器實例僅負責更新約 1 / world_size 的參數,因此只需要保留 1 / world_size 的優化器狀態。在本地更新參數後,每個 rank 將其參數廣播到所有其他 peer,以保持所有模型副本處於相同的狀態。ZeroRedundancyOptimizer 可以與 torch.nn.parallel.DistributedDataParallel 結合使用,以減少每個 rank 的峰值記憶體消耗。

ZeroRedundancyOptimizer 使用經過排序的貪婪演算法來封裝每個 rank 中的多個參數。每個參數都屬於單個 rank,並且不在 rank 之間分割。此分割是任意的,並且可能與參數註冊或使用順序不符。

參數

params (Iterable) – 一個 Iterabletorch.Tensor s 或 dict s,提供將在各個 rank 間分片的所有參數。

關鍵字引數
  • optimizer_class (torch.nn.Optimizer) – 本地優化器的類別。

  • process_group (ProcessGroup, 選用) – torch.distributed ProcessGroup (預設值:由 torch.distributed.init_process_group() 初始化的 dist.group.WORLD)。

  • parameters_as_bucket_view (bool, 選用) – 如果為 True,參數將被封裝到 bucket 中以加速通訊,並且 param.data 欄位將指向不同偏移量的 bucket 檢視;如果為 False,則每個單獨的參數會單獨傳輸,並且每個 params.data 保持不變 (預設值:False)。

  • overlap_with_ddp (bool, 選用) – 如果為 True,則 step()DistributedDataParallel 的梯度同步重疊;這需要 (1) optimizer_class 引數的函數式優化器,或者具有函數式等效項的優化器,以及 (2) 註冊一個由 ddp_zero_hook.py 中函數之一建構的 DDP 通訊 hook;參數被封裝到與 DistributedDataParallel 中匹配的 bucket 中,這意味著 parameters_as_bucket_view 引數被忽略。如果為 False,則 step() 在反向傳播之後不相交地執行(按照正常)。(預設值:False

  • **defaults – 任何尾隨引數,這些引數會轉發到本地優化器。

範例

>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

警告

目前,ZeroRedundancyOptimizer 要求所有傳入的參數都是相同的密集型。

警告

如果您傳遞 overlap_with_ddp=True,請注意以下事項:鑑於目前實現的將 DistributedDataParallelZeroRedundancyOptimizer 重疊的方式,前兩個或三個訓練迭代不會在優化器步驟中執行參數更新,具體取決於 static_graph=Falsestatic_graph=True。這是因為它需要有關 DistributedDataParallel 使用的梯度 bucket 策略的資訊,如果 static_graph=False,則直到第二次正向傳播才會最終確定,如果 static_graph=True,則直到第三次正向傳播才會最終確定。要對此進行調整,一種選擇是預先添加虛擬輸入。

警告

ZeroRedundancyOptimizer 是實驗性的,可能會發生更改。

add_param_group(param_group)[原始碼][原始碼]

將一個參數群組添加到 Optimizerparam_groups 中。

當微調預訓練網路時,這會很有用,因為可以將凍結層變為可訓練層,並在訓練過程中將其添加到 Optimizer 中。

參數

param_group (dict) – 指定要優化的參數和群組特定的優化選項。

警告

此方法處理更新所有分割區上的 shards,但需要在所有 ranks 上呼叫。在部分 ranks 上呼叫此方法會導致訓練停止,因為根據管理的參數呼叫了通信原語,並期望所有 ranks 都參與相同的參數集。

consolidate_state_dict(to=0)[原始碼][原始碼]

state_dict 的列表 (每個 rank 一個) 合併到目標 rank 上。

參數

to (int) – 接收優化器狀態的 rank (預設值: 0)。

引發

RuntimeError – 如果 overlap_with_ddp=True 且此方法在 ZeroRedundancyOptimizer 實例完全初始化之前呼叫,這會在 DistributedDataParallel 梯度 buckets 重新構建後發生。

警告

需要在所有 ranks 上呼叫此方法。

property join_device: device

傳回預設裝置。

join_hook(**kwargs)[原始碼][原始碼]

傳回 ZeRO join hook。

它通過 shadow 優化器步驟中的集合通信,來實現對不均勻輸入的訓練。

在呼叫此 hook 之前,必須正確設定梯度。

參數

kwargs (dict) – 一個 dict,包含任何關鍵字參數,以在執行時修改 join hook 的行為;所有共享相同 join context manager 的 Joinable 實例都會轉發 kwargs 的相同值。

此 hook 不支援任何關鍵字參數;即 kwargs 未使用。

property join_process_group: Any

傳回 process group。

load_state_dict(state_dict)[原始碼][原始碼]

從輸入 state_dict 載入與給定 rank 相關的狀態,並根據需要更新本地優化器。

參數

state_dict (dict) – 優化器狀態;應該是從呼叫 state_dict() 傳回的物件。

引發

RuntimeError – 如果 overlap_with_ddp=True 且此方法在 ZeroRedundancyOptimizer 實例完全初始化之前呼叫,這會在 DistributedDataParallel 梯度 buckets 重新構建後發生。

state_dict()[原始碼][原始碼]

傳回此 rank 已知的最後一個全域優化器狀態。

引發

如果 overlap_with_ddp=True,且此方法在 ZeroRedundancyOptimizer 實例完全初始化之前被呼叫(這會在 DistributedDataParallel 梯度 buckets 被重建後發生);或者如果此方法在未先呼叫 consolidate_state_dict() 的情況下被呼叫,則會引發 RuntimeError

回傳型別

Dict[str, Any]

step(closure=None, **kwargs)[source][source]

執行單一最佳化器步驟,並在所有 ranks 之間同步參數。

參數

closure (Callable) – 一個重新評估模型並回傳損失的 closure;對於大多數最佳化器是可選的。

回傳

根據底層本地最佳化器,可能回傳損失。

回傳型別

Optional[float]

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學課程

取得針對初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並取得問題解答

檢視資源