捷徑

torch.optim

torch.optim 是一個實現各種優化演算法的套件。

最常用的方法已經支援,而且介面也夠通用,因此未來可以輕鬆整合更複雜的方法。

如何使用優化器

要使用 torch.optim,您必須建構一個優化器物件,該物件將持有目前狀態,並根據計算出的梯度更新參數。

建構它

要建構一個 Optimizer,您必須提供一個包含要優化的參數(所有參數都應該是 Parameter)或具名參數((str, Parameter) 的 tuple)的可迭代物件。然後,您可以指定優化器特定的選項,例如學習率、權重衰減等。

範例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

具名參數範例

optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)

每個參數的選項

Optimizer 也支援指定每個參數的選項。若要這樣做,不要傳遞 Variable 的可迭代物件,而是傳遞 dict 的可迭代物件。它們中的每一個都會定義一個單獨的參數組,並且應該包含一個 params 鍵,其中包含屬於它的參數列表。其他鍵應該與優化器接受的關鍵字參數相符,並將用作該組的優化選項。

例如,當想要指定每層的學習率時,這非常有用

optim.SGD([
                {'params': model.base.parameters(), 'lr': 1e-2},
                {'params': model.classifier.parameters()}
            ], lr=1e-3, momentum=0.9)

optim.SGD([
                {'params': model.base.named_parameters(), 'lr': 1e-2},
                {'params': model.classifier.named_parameters()}
            ], lr=1e-3, momentum=0.9)

這表示 model.base 的參數將使用 1e-2 的學習率,而 model.classifier 的參數將堅持使用預設的學習率 1e-3。最後,所有參數都將使用 0.9 的動量。

注意

您仍然可以將選項作為關鍵字參數傳遞。它們將用作預設值,用於未覆蓋它們的群組。當您只想更改單個選項,同時保持所有其他選項在參數群組之間一致時,這很有用。

另請考慮以下與參數的不同懲罰相關的範例。請記住,parameters() 傳回一個可迭代物件,其中包含所有可學習的參數,包括可能更喜歡不同懲罰的偏差和其他參數。為了處理這個問題,可以為每個參數組指定單獨的懲罰權重

bias_params = [p for name, p in self.named_parameters() if 'bias' in name]
others = [p for name, p in self.named_parameters() if 'bias' not in name]

optim.SGD([
                {'params': others},
                {'params': bias_params, 'weight_decay': 0}
            ], weight_decay=1e-2, lr=1e-2)

以這種方式,偏差項與非偏差項隔離,並且專門為偏差項設定 0weight_decay,以避免對該組進行任何懲罰。

執行優化步驟

所有優化器都實現了一個 step() 方法,用於更新參數。它可以使用兩種方式

optimizer.step()

這是大多數優化器支援的簡化版本。一旦使用例如 backward() 計算出梯度,就可以調用該函數。

範例

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

optimizer.step(closure)

某些優化演算法(例如共軛梯度和 LBFGS)需要多次重新評估該函數,因此您必須傳入一個閉包,使其能夠重新計算您的模型。閉包應該清除梯度,計算損失並返回它。

範例

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

基底類別

class torch.optim.Optimizer(params, defaults)[source][source]

所有優化器的基底類別。

警告

參數需要指定為集合,這些集合具有確定性的順序,並且在多次執行之間保持一致。不滿足這些屬性的物件範例是集合以及字典值的迭代器。

參數
  • params (iterable) – torch.Tensor s 或 dict s 的可迭代物件。指定應該優化的 Tensor。

  • defaults (Dict[str, Any]) – (dict): 包含優化選項的預設值的 dict(當參數群組未指定它們時使用)。

Optimizer.add_param_group

將參數群組添加到 Optimizerparam_groups

Optimizer.load_state_dict

載入優化器狀態。

Optimizer.register_load_state_dict_pre_hook

註冊一個 load_state_dict pre-hook,它將在調用 load_state_dict() 之前調用。它應該具有以下簽名::。

Optimizer.register_load_state_dict_post_hook

註冊一個 load_state_dict 後置鉤子 (post-hook),它會在 load_state_dict() 呼叫後被呼叫。它應該具有以下簽名::

Optimizer.state_dict

dict 的形式傳回最佳化器的狀態。

Optimizer.register_state_dict_pre_hook

註冊一個 state dict 前置鉤子 (pre-hook),它會在 state_dict() 呼叫之前被呼叫。

Optimizer.register_state_dict_post_hook

註冊一個 state dict 後置鉤子 (post-hook),它會在 state_dict() 呼叫之後被呼叫。

Optimizer.step

執行單一步最佳化以更新參數。

Optimizer.register_step_pre_hook

註冊一個最佳化器 step 前置鉤子 (pre-hook),它會在最佳化器 step 之前被呼叫。

Optimizer.register_step_post_hook

註冊一個最佳化器 step 後置鉤子 (post-hook),它會在最佳化器 step 之後被呼叫。

Optimizer.zero_grad

重置所有最佳化的 torch.Tensor 的梯度。

演算法

Adadelta

實作 Adadelta 演算法。

Adafactor

實作 Adafactor 演算法。

Adagrad

實作 Adagrad 演算法。

Adam

實作 Adam 演算法。

AdamW

實作 AdamW 演算法。

SparseAdam

SparseAdam 實作 Adam 演算法的遮罩版本,適用於稀疏梯度。

Adamax

實作 Adamax 演算法 (基於無窮範數的 Adam 變體)。

ASGD

實作平均隨機梯度下降法。

LBFGS

實作 L-BFGS 演算法。

NAdam

實作 NAdam 演算法。

RAdam

實作 RAdam 演算法。

RMSprop

實作 RMSprop 演算法。

Rprop

實作彈性反向傳播演算法。

SGD

實作隨機梯度下降法 (可選地帶有動量)。

我們的許多演算法都有針對效能、可讀性和/或通用性優化的各種實作,因此如果使用者沒有指定任何特定的實作,我們會嘗試預設使用目前裝置上通常最快的實作。

我們有 3 個主要的實作類別:for-loop、foreach (multi-tensor) 和 fused。 最直接的實作是透過具有大量計算區塊的參數進行 for 迴圈。 For-looping 通常比我們的 foreach 實作慢,後者將參數組合到一個 multi-tensor 中,並一次執行大量的計算區塊,從而節省了許多連續的 kernel 呼叫。 我們的一些最佳化器甚至具有更快的 fused 實作,這些實作將大量的計算區塊融合到一個 kernel 中。 我們可以將 foreach 實作視為水平融合,將 fused 實作視為在其之上的垂直融合。

一般來說,3 種實作的效能順序是 fused > foreach > for-loop。 因此,在適用的情況下,我們預設使用 foreach 而不是 for-loop。 適用表示 foreach 實作可用、使用者未指定任何特定於實作的 kwargs(例如,fused、foreach、differentiable),並且所有張量都是原生的。 請注意,雖然 fused 應該比 foreach 更快,但這些實作比較新,我們希望在全面切換之前給予它們更多的燒機時間。 我們在下面的第二個表中總結了每個實作的穩定性狀態,歡迎您嘗試它們!

下表顯示了每個演算法的可用和預設實作

演算法

預設

有 foreach 嗎?

有 fused 嗎?

Adadelta

foreach

Adafactor

for-loop

Adagrad

foreach

是 (僅限 cpu)

Adam

foreach

AdamW

foreach

SparseAdam

for-loop

Adamax

foreach

ASGD

foreach

LBFGS

for-loop

NAdam

foreach

RAdam

foreach

RMSprop

foreach

Rprop

foreach

SGD

foreach

下表顯示了 fused 實作的穩定性狀態

演算法

CPU

CUDA

MPS

Adadelta

不支援

不支援

不支援

Adafactor

不支援

不支援

不支援

Adagrad

beta

不支援

不支援

Adam

beta

穩定

beta

AdamW

beta

穩定

beta

SparseAdam

不支援

不支援

不支援

Adamax

不支援

不支援

不支援

ASGD

不支援

不支援

不支援

LBFGS

不支援

不支援

不支援

NAdam

不支援

不支援

不支援

RAdam

不支援

不支援

不支援

RMSprop

不支援

不支援

不支援

Rprop

不支援

不支援

不支援

SGD

beta

beta

beta

如何調整學習率

torch.optim.lr_scheduler.LRScheduler 提供了幾種根據 epoch 數量調整學習率的方法。 torch.optim.lr_scheduler.ReduceLROnPlateau 允許根據一些驗證量測動態降低學習率。

學習率排程應該在最佳化器的更新之後應用;例如,您應該這樣編寫您的程式碼

範例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()

大多數學習率排程器可以背靠背呼叫 (也稱為鏈式排程器)。 這樣做的結果是,每個排程器都按順序應用於前一個排程器獲得的學習率。

範例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()

在文件中的許多地方,我們將使用以下範本來引用排程器演算法。

>>> scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()

警告

在 PyTorch 1.1.0 之前,學習率排程器預計在最佳化器的更新之前被呼叫; 1.1.0 以 BC 中斷的方式改變了這種行為。 如果您在最佳化器的更新 (呼叫 optimizer.step()) 之前使用學習率排程器 (呼叫 scheduler.step()),這將跳過學習率排程的第一個值。 如果您在升級到 PyTorch 1.1.0 後無法重現結果,請檢查您是否在錯誤的時間呼叫 scheduler.step()

lr_scheduler.LRScheduler

在最佳化期間調整學習率。

lr_scheduler.LambdaLR

設定初始學習率。

lr_scheduler.MultiplicativeLR

將每個參數群組的學習率乘以指定函數中給定的因子。

lr_scheduler.StepLR

每隔 step_size 個 epochs,將每個參數群組的學習率衰減 gamma 倍。

lr_scheduler.MultiStepLR

一旦 epoch 數達到其中一個里程碑,就將每個參數群組的學習率衰減 gamma 倍。

lr_scheduler.ConstantLR

將每個參數群組的學習率乘以一個小的常數因子。

lr_scheduler.LinearLR

通過線性改變小的乘法因子來衰減每個參數組的學習率。

lr_scheduler.ExponentialLR

每個 epoch 將每個參數組的學習率衰減 gamma 倍。

lr_scheduler.PolynomialLR

使用給定的 total_iters 中的多項式函數衰減每個參數組的學習率。

lr_scheduler.CosineAnnealingLR

使用餘弦退火排程設定每個參數群組的學習率。

lr_scheduler.ChainedScheduler

將一系列學習率排程器串聯起來。

lr_scheduler.SequentialLR

包含一個在優化過程中預期會依序呼叫的排程器列表。

lr_scheduler.ReduceLROnPlateau

當某個指標停止改善時,降低學習率。

lr_scheduler.CyclicLR

根據循環學習率策略 (CLR) 設定每個參數群組的學習率。

lr_scheduler.OneCycleLR

根據 1cycle 學習率策略設定每個參數群組的學習率。

lr_scheduler.CosineAnnealingWarmRestarts

使用餘弦退火排程設定每個參數群組的學習率。

如何使用具名參數載入 optimizer 狀態字典

如果存在,函數 load_state_dict() 會儲存從載入的狀態字典中取得的可選 param_names 內容。但是,載入 optimizer 狀態的過程不會受到影響,因為參數的順序對於維持相容性非常重要 (以防順序不同)。若要使用從載入的狀態字典中載入的參數名稱,需要根據所需的行為實作一個自訂的 register_load_state_dict_pre_hook

這可能很有用,例如,當模型架構發生變化,但權重和 optimizer 狀態需要保持不變時。以下範例示範如何實作此自訂。

範例

class OneLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)

    def forward(self, x):
        return self.fc(x)

model = OneLayerModel()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

假設 model 實作了一個專家模型 (MoE),並且我們想要複製它並恢復對兩個專家的訓練,兩者的初始化方式都與 fc 層相同。對於以下 model2,我們建立兩個與 fc 相同的層,並透過將 model 的模型權重和 optimizer 狀態載入到 model2fc1fc2 中來恢復訓練(並相應地調整它們)

class TwoLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(3, 4)

    def forward(self, x):
        return (self.fc1(x) + self.fc2(x)) / 2

model2 = TwoLayerModel()
# adapt and load model weights..
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

若要使用先前 optimizer 的狀態字典載入 optimizer2 的狀態字典,以便 fc1fc2 都以 fc optimizer 狀態的副本初始化 (以便從 fc 恢復每層的訓練),我們可以使用以下 hook

def adapt_state_dict_ids(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc1.weight': 'fc.weight',
        'fc1.bias': 'fc.bias',
        'fc2.weight': 'fc.weight',
        'fc2.bias': 'fc.bias'
    }
    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
        id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
        # Copy the state of the corresponding parameter
        if id_in_loaded in state_dict['state']:
            adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

這確保了在模型載入期間將使用針對 model2 層的正確狀態調整後的 state_dict。請注意,此程式碼專為此範例設計 (例如,假設單一參數群組),其他情況可能需要不同的調整。

以下範例示範了當模型結構發生變化時,如何處理載入的 state dict 中缺少的參數。Model_bypass 新增了一個新的 bypass 層,該層不存在於原始 Model1 中。為了恢復訓練,使用自訂的 adapt_state_dict_missing_param hook 來調整 optimizer 的 state_dict,確保現有參數被正確映射,而缺少的參數 (如 bypass 層) 保持不變 (如本範例中初始化)。儘管模型發生變化,此方法也能夠順利載入和恢復 optimizer 狀態。新的 bypass 層將從頭開始訓練

class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x) + x


model = Model1()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

class Model_bypass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)
        self.bypass = nn.Linear(5, 5, bias=False)
        torch.nn.init.eye_(self.bypass.weight)

    def forward(self, x):
        return self.fc(x) + self.bypass(x)

model2 = Model_bypass()
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

def adapt_state_dict_missing_param(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc.weight': 'fc.weight',
        'fc.bias': 'fc.bias',
        'bypass.weight': None,
    }

    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        if name_in_loaded in state_dict['param_groups'][0]['param_names']:
            index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
            id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

作為第三個範例,這個 hook 可以用來根據參數的名稱載入狀態,而不是根據參數的順序載入狀態 (預設方法)。

def names_matching(optimizer, state_dict):
    assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
    adapted_state_dict = deepcopy(optimizer.state_dict())
    for g_ind in range(len(state_dict['param_groups'])):
        assert len(state_dict['param_groups'][g_ind]['params']) == len(
            optimizer.state_dict()['param_groups'][g_ind]['params'])

        for k, v in state_dict['param_groups'][g_ind].items():
            if k not in ['params', 'param_names']:
                adapted_state_dict['param_groups'][g_ind][k] = v

        for param_id, param_name in zip(
                optimizer.state_dict()['param_groups'][g_ind]['params'],
                optimizer.state_dict()['param_groups'][g_ind]['param_names']):
            index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
            id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

權重平均 (SWA 和 EMA)

torch.optim.swa_utils.AveragedModel 實作了隨機權重平均 (SWA) 和指數移動平均 (EMA),torch.optim.swa_utils.SWALR 實作了 SWA 學習率排程器,而 torch.optim.swa_utils.update_bn() 是一個公用程式函數,用於在訓練結束時更新 SWA/EMA 批次正規化統計資料。

SWA 已在 平均權重會產生更廣泛的 Optima 和更好的泛化效果 中提出。

EMA 是一種廣為人知的技術,透過減少所需的權重更新次數來縮短訓練時間。它是 Polyak 平均 的一種變體,但使用指數權重而不是跨迭代的相等權重。

建構平均模型

AveragedModel 類別用於計算 SWA 或 EMA 模型的權重。

您可以透過執行以下操作來建立 SWA 平均模型

>>> averaged_model = AveragedModel(model)

EMA 模型透過指定 multi_avg_fn 參數來建構,如下所示

>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))

Decay 是一個介於 0 和 1 之間的參數,用於控制平均參數衰減的速度。如果沒有提供給 torch.optim.swa_utils.get_ema_multi_avg_fn(),則預設值為 0.999。 Decay 值應接近 1.0,因為較小的值可能會導致最佳化收斂問題。

torch.optim.swa_utils.get_ema_multi_avg_fn() 傳回一個將以下 EMA 方程式套用至權重的函數

Wt+1EMA=αWtEMA+(1α)WtmodelW^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t

其中 alpha 為 EMA 衰減率。

這裡的 model 可以是任意的 torch.nn.Module 物件。averaged_model 將會追蹤 model 參數的滾動平均值。 為了更新這些平均值,您應該在 optimizer.step() 之後使用 update_parameters() 函數

>>> averaged_model.update_parameters(model)

對於 SWA 和 EMA,此呼叫通常在優化器 step() 之後立即完成。在 SWA 的情況下,通常會在訓練開始時跳過一些步驟。

自訂平均策略

預設情況下,torch.optim.swa_utils.AveragedModel 計算您提供的參數的滾動等權平均值,但您也可以使用具有 avg_fnmulti_avg_fn 參數的自訂平均函數

  • avg_fn 允許定義一個作用於每個參數元組(平均參數,模型參數)的函數,並且應該返回新的平均參數。

  • multi_avg_fn 允許定義更有效率的操作,同時作用於參數列表的元組(平均參數列表,模型參數列表),例如使用 torch._foreach* 函數。 此函數必須就地更新平均參數。

在下面的範例中,ema_model 使用 avg_fn 參數計算指數移動平均

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

在下面的範例中,ema_model 使用更有效率的 multi_avg_fn 參數計算指數移動平均

>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))

SWA 學習率排程

通常,在 SWA 中,學習率設定為高的常數值。SWALR 是一個學習率排程器,它將學習率退火到一個固定值,然後保持不變。 例如,以下程式碼建立一個排程器,該排程器在每個參數群組中於 5 個 epoch 內將學習率從其初始值線性退火到 0.05

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

您也可以使用餘弦退火到一個固定值,而不是線性退火,方法是設定 anneal_strategy="cos"

注意批次正規化

update_bn() 是一個實用函數,允許在訓練結束時計算給定資料載入器 loader 上 SWA 模型的批次正規化統計量

>>> torch.optim.swa_utils.update_bn(loader, swa_model)

update_bn()swa_model 應用於資料載入器中的每個元素,並計算模型中每個批次正規化層的啟用統計量。

警告

update_bn() 假設資料載入器 loader 中的每個批次都是張量或張量列表,其中第一個元素是網路 swa_model 應用的張量。 如果您的資料載入器具有不同的結構,您可以通過對資料集的每個元素使用 swa_model 進行前向傳遞來更新 swa_model 的批次正規化統計量。

整合在一起:SWA

在下面的範例中,swa_model 是累積權重平均值的 SWA 模型。 我們總共訓練模型 300 個 epoch,並且我們切換到 SWA 學習率排程,並在 epoch 160 開始收集參數的 SWA 平均值

>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)

整合在一起:EMA

在下面的範例中,ema_model 是一個 EMA 模型,它使用 0.999 的衰減率來累計權重的指數衰減平均值。 我們總共訓練模型 300 個 epoch,並立即開始收集 EMA 平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>           ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)

swa_utils.AveragedModel

實現隨機權重平均 (Stochastic Weight Averaging, SWA) 和指數移動平均 (Exponential Moving Average, EMA) 的平均模型。

swa_utils.SWALR

將每個參數群組中的學習率退火到一個固定值。

torch.optim.swa_utils.get_ema_multi_avg_fn(decay=0.999)[source][source]

獲取在多個參數上應用指數移動平均 (EMA) 的函數。

torch.optim.swa_utils.update_bn(loader, model, device=None)[source][source]

更新模型中的 BatchNorm running_mean 和 running_var 緩衝區。

它在 loader 中的數據上執行一次遍歷,以估計模型中 BatchNorm 層的激活統計信息。

參數
  • loader (torch.utils.data.DataLoader) – 用於計算激活統計信息的資料集載入器。每個數據批次應該是一個張量,或者是一個列表/元組,其第一個元素是一個包含數據的張量。

  • model (torch.nn.Module) – 我們尋求更新 BatchNorm 統計信息的模型。

  • device (torch.device, optional) – 如果設定,資料將在傳遞到 model 之前傳輸到 device

範例

>>> loader, model = ...
>>> torch.optim.swa_utils.update_bn(loader, model)

注意

update_bn 工具假設 loader 中的每個資料批次都是一個張量,或者是一個張量列表或元組;在後一種情況下,假設應該在與數據批次對應的列表或元組的第一個元素上調用 model.forward()

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

獲取初學者和高級開發人員的深入教學

檢視教學

資源

尋找開發資源並解答您的問題

檢視資源