捷徑

具有 TorchScript 支援的分散式最佳化器

建立於:2021 年 4 月 26 日 | 最後更新:2024 年 12 月 02 日 | 最後驗證:2024 年 11 月 05 日

警告

TorchScript 不再積極開發。

在本指南中,您將學習

  • 具有 TorchScript 支援的分散式最佳化器的高階概念以及此功能帶來的優勢

  • 如何編寫啟用 TorchScript 支援的自訂分散式最佳化器

需求

什麼是分散式最佳化器?

DistributedOptimizer 接收遠端參數(RRef)的列表,並在參數所在的 workers 上在本機執行最佳化器,這通常與分散式 RPC/Autograd 一起使用以進行模型平行訓練。 它可以使用任何本地最佳化器演算法(torch.optim 中提供的預定義演算法或自訂定義的演算法)來將梯度應用於每個 worker。

什麼是具有 TorchScript 支援的分散式最佳化器?

分散式最佳化器廣泛用於分散式模型平行訓練,在一些常見的使用案例中,由於效能考量和資源利用率(或至少部分多執行緒,例如,參數伺服器託管模型和參數的一部分,新執行緒根據請求更新參數),訓練需要在多執行緒方式下完成,而不是多進程方式。 PyTorch 本身並不原生支援多執行緒訓練,因為它受到 Python 的全域解釋器鎖 (GIL) 的影響,但它可以利用 TorchScript 來擺脫 GIL 並以多執行緒方式執行模型。

對於關鍵的模型訓練工作負載,提高訓練效能是一個重要的課題。 研究人員通常希望通過圖形表示(即通過運算子融合)或實作自訂運算子核心來實作不同的最佳化策略,以加速訓練。

具有 TorchScript 支援的分散式最佳化器可以幫助擺脫 GIL,從而提高 PyTorch 在多執行緒環境中的訓練效能,它還解鎖了通過使用 TorchScript 提供的進階編譯器技術(即 CPU/GPU 融合)來進一步提高效能的潛力。

如何編寫具有 TorchScript 支援的自訂分散式最佳化器?

下面的程式碼顯示了如何給定現有的本地最佳化器實作編寫自訂分散式最佳化器,這解鎖了 TorchScript 的優勢,包括消除 GIL 和提高效能的機會。

假設您已經有一個目前在訓練期間使用的本地最佳化器,在本例中,我們將使用 quasi-hyperbolic momentum (QHM) 作為一個範例,來說明如何啟用 TorchScript 支援,請注意,它也適用於任何繼承自 torch.optim.Optimizer 的自訂最佳化器。

首先,我們需要將計算和狀態管理與最佳化器實作分開,這樣我們就可以提取計算部分並將其變成一個自由函數,這對 TorchScript 來說很友善。 它有兩個好處:1. 計算邏輯變得更容易檢查,它使我們能夠快速地將參數更新/計算部分轉變為 TorchScript,並利用 TorchScript IR 進行進一步的最佳化(運算子融合等)。 2. 分散式最佳化器底層使用不同的機制來獲取梯度和更新參數(我們分別儲存梯度,而不是在向後傳播期間直接填充 param.grad 欄位)。 分離計算使分散式最佳化器能夠在多執行緒模式下啟用最佳化器更新的可能性,因為它消除了對 param.grad 的可能競爭條件。

import torch
from torch import Tensor
from typing import List


def qhm_update(params: List[Tensor],
            dp_list: List[Tensor],
            momentum_buffer_list: List[Tensor],
            lr: float,
            nu: float,
            weight_decay: float,
            weight_decay_type: str,
            momentum: float):

    for p, d_p, momentum_buffer in zip(params, dp_list, momentum_buffer_list):
        if weight_decay != 0:
            if weight_decay_type == "grad":
                d_p.add_(weight_decay, p)
            elif weight_decay_type == "direct":
                p.mul_(1.0 - lr * weight_decay)
            else:
                raise ValueError("Invalid weight decay type provided")

        momentum_buffer.mul_(momentum).add_(1.0 - momentum, d_p)

        p.data.add_(-lr * nu, momentum_buffer)
        p.data.add_(-lr * (1.0 - nu), d_p)

接下來,我們將定義一個具有 TorchScript 相容性的分散式功能最佳化器,以管理最佳化器狀態並呼叫我們上面定義的 TorchScript 相容更新函數。 請注意,與普通自訂最佳化器相比,有一些約定不同:1. 我們不繼承 torch.optim.Optimizer,因為 TorchScript 不支援多型。 2. step 接收梯度列表,而不是損失閉包。

import torch
from torch import Tensor
from typing import List, Optional, Dict

# define this as a TorchScript class
@torch.jit.script
class FunctionalQHM(object):
    def __init__(self,
                params: List[Tensor],
                lr: float,
                momentum: float,
                nu: float,
                weight_decay: float = 0.0,
                weight_decay_type: str = "grad"):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if weight_decay_type not in ("grad", "direct"):
            raise ValueError("Invalid weight_decay_type value: {}".format(weight_decay_type))

        self.defaults = {
            "lr": lr,
            "momentum": momentum,
            "nu": nu,
            "weight_decay": weight_decay,
        }
        self.weight_decay_type = weight_decay_type

        # NOTE: we only have one param_group here and don't allow user to add additional
        # param group as it's not a common use case.
        self.param_group = {"params": params}

        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group['params']
        params_with_grad = []
        grads = []
        momentum_buffer_list: List[Tensor] = []

        if len(params) != len(gradients):
            raise ValueError(
                "the gradients passed in does not equal to the size of the parameters!"
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        for param, gradient in zip(self.param_group['params'], gradients):
            if gradient is not None:
                params_with_grad.append(param)
                grads.append(gradient)
                state = self.state[param]
                state['momentum_buffer'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                momentum_buffer_list.append(state['momentum_buffer'])

        # calls into the update function we just defined
        with torch.no_grad():
            qhm_update(params_with_grad,
                    grads,
                    momentum_buffer_list,
                    self.defaults['lr'],
                    self.defaults['nu'],
                    self.defaults['weight_decay'],
                    self.weight_decay_type,
                    self.defaults['momentum'])

最後,我們將新定義的分散式功能最佳化器註冊到 functional_optim_map 中。 這樣,DistributedOptimizer 將嘗試提取我們的自訂實作,而不是預定義的預設實作。

from torch.distributed.optim import DistributedOptimizer

DistributedOptimizer.functional_optim_map[QHM] = FunctionalQHM

現在,您可以像往常一樣在分散式訓練中使用 QHM 最佳化器,方法是將其傳遞給 DistributedOptimizer

...
remote_params_list = [...]
dist_optim = DistributedOptimizer(
    QHM, remote_params_list, *args, **kwargs
)

DistributedOptimizer 將自動在幕後將 QHM 最佳化器轉換為 FunctionalQHM,並啟用 TorchScript 支援。 這將解鎖由多執行緒訓練推動的效能,並為進一步的改進提供更多潛力(即 TorchScript 融合等)。

請注意,大多數 PyTorch 內建最佳化器已經使用這種方法來加速分散式訓練。 如果您看到關於某些最佳化器尚未轉換的警告,您可以按照本指南編寫自己的轉換。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得解答

檢視資源