torch.optim¶
torch.optim
是一個實現各種優化演算法的套件。
最常用的方法已經支援,而且介面也夠通用,因此未來可以輕鬆整合更複雜的方法。
如何使用優化器¶
要使用 torch.optim
,您必須建構一個優化器物件,該物件將持有目前狀態,並根據計算出的梯度更新參數。
建構它¶
要建構一個 Optimizer
,您必須提供一個包含要優化的參數(所有參數都應該是 Parameter
)或具名參數((str, Parameter
) 的 tuple)的可迭代物件。然後,您可以指定優化器特定的選項,例如學習率、權重衰減等。
範例
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
具名參數範例
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)
每個參數的選項¶
Optimizer
也支援指定每個參數的選項。若要這樣做,不要傳遞 Variable
的可迭代物件,而是傳遞 dict
的可迭代物件。它們中的每一個都會定義一個單獨的參數組,並且應該包含一個 params
鍵,其中包含屬於它的參數列表。其他鍵應該與優化器接受的關鍵字參數相符,並將用作該組的優化選項。
例如,當想要指定每層的學習率時,這非常有用
optim.SGD([
{'params': model.base.parameters(), 'lr': 1e-2},
{'params': model.classifier.parameters()}
], lr=1e-3, momentum=0.9)
optim.SGD([
{'params': model.base.named_parameters(), 'lr': 1e-2},
{'params': model.classifier.named_parameters()}
], lr=1e-3, momentum=0.9)
這表示 model.base
的參數將使用 1e-2
的學習率,而 model.classifier
的參數將堅持使用預設的學習率 1e-3
。最後,所有參數都將使用 0.9
的動量。
注意
您仍然可以將選項作為關鍵字參數傳遞。它們將用作預設值,用於未覆蓋它們的群組。當您只想更改單個選項,同時保持所有其他選項在參數群組之間一致時,這很有用。
另請考慮以下與參數的不同懲罰相關的範例。請記住,parameters()
傳回一個可迭代物件,其中包含所有可學習的參數,包括可能更喜歡不同懲罰的偏差和其他參數。為了處理這個問題,可以為每個參數組指定單獨的懲罰權重
bias_params = [p for name, p in self.named_parameters() if 'bias' in name]
others = [p for name, p in self.named_parameters() if 'bias' not in name]
optim.SGD([
{'params': others},
{'params': bias_params, 'weight_decay': 0}
], weight_decay=1e-2, lr=1e-2)
以這種方式,偏差項與非偏差項隔離,並且專門為偏差項設定 0
的 weight_decay
,以避免對該組進行任何懲罰。
執行優化步驟¶
所有優化器都實現了一個 step()
方法,用於更新參數。它可以使用兩種方式
optimizer.step()
¶
這是大多數優化器支援的簡化版本。一旦使用例如 backward()
計算出梯度,就可以調用該函數。
範例
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.step(closure)
¶
某些優化演算法(例如共軛梯度和 LBFGS)需要多次重新評估該函數,因此您必須傳入一個閉包,使其能夠重新計算您的模型。閉包應該清除梯度,計算損失並返回它。
範例
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
return loss
optimizer.step(closure)
基底類別¶
- class torch.optim.Optimizer(params, defaults)[source][source]¶
所有優化器的基底類別。
警告
參數需要指定為集合,這些集合具有確定性的順序,並且在多次執行之間保持一致。不滿足這些屬性的物件範例是集合以及字典值的迭代器。
- 參數
params (iterable) –
torch.Tensor
s 或dict
s 的可迭代物件。指定應該優化的 Tensor。defaults (Dict[str, Any]) – (dict): 包含優化選項的預設值的 dict(當參數群組未指定它們時使用)。
將參數群組添加到 |
|
載入優化器狀態。 |
|
註冊一個 load_state_dict pre-hook,它將在調用 |
|
註冊一個 load_state_dict 後置鉤子 (post-hook),它會在 |
|
以 |
|
註冊一個 state dict 前置鉤子 (pre-hook),它會在 |
|
註冊一個 state dict 後置鉤子 (post-hook),它會在 |
|
執行單一步最佳化以更新參數。 |
|
註冊一個最佳化器 step 前置鉤子 (pre-hook),它會在最佳化器 step 之前被呼叫。 |
|
註冊一個最佳化器 step 後置鉤子 (post-hook),它會在最佳化器 step 之後被呼叫。 |
|
重置所有最佳化的 |
演算法¶
實作 Adadelta 演算法。 |
|
實作 Adafactor 演算法。 |
|
實作 Adagrad 演算法。 |
|
實作 Adam 演算法。 |
|
實作 AdamW 演算法。 |
|
SparseAdam 實作 Adam 演算法的遮罩版本,適用於稀疏梯度。 |
|
實作 Adamax 演算法 (基於無窮範數的 Adam 變體)。 |
|
實作平均隨機梯度下降法。 |
|
實作 L-BFGS 演算法。 |
|
實作 NAdam 演算法。 |
|
實作 RAdam 演算法。 |
|
實作 RMSprop 演算法。 |
|
實作彈性反向傳播演算法。 |
|
實作隨機梯度下降法 (可選地帶有動量)。 |
我們的許多演算法都有針對效能、可讀性和/或通用性優化的各種實作,因此如果使用者沒有指定任何特定的實作,我們會嘗試預設使用目前裝置上通常最快的實作。
我們有 3 個主要的實作類別:for-loop、foreach (multi-tensor) 和 fused。 最直接的實作是透過具有大量計算區塊的參數進行 for 迴圈。 For-looping 通常比我們的 foreach 實作慢,後者將參數組合到一個 multi-tensor 中,並一次執行大量的計算區塊,從而節省了許多連續的 kernel 呼叫。 我們的一些最佳化器甚至具有更快的 fused 實作,這些實作將大量的計算區塊融合到一個 kernel 中。 我們可以將 foreach 實作視為水平融合,將 fused 實作視為在其之上的垂直融合。
一般來說,3 種實作的效能順序是 fused > foreach > for-loop。 因此,在適用的情況下,我們預設使用 foreach 而不是 for-loop。 適用表示 foreach 實作可用、使用者未指定任何特定於實作的 kwargs(例如,fused、foreach、differentiable),並且所有張量都是原生的。 請注意,雖然 fused 應該比 foreach 更快,但這些實作比較新,我們希望在全面切換之前給予它們更多的燒機時間。 我們在下面的第二個表中總結了每個實作的穩定性狀態,歡迎您嘗試它們!
下表顯示了每個演算法的可用和預設實作
演算法 |
預設 |
有 foreach 嗎? |
有 fused 嗎? |
---|---|---|---|
foreach |
是 |
否 |
|
for-loop |
否 |
否 |
|
foreach |
是 |
是 (僅限 cpu) |
|
foreach |
是 |
是 |
|
foreach |
是 |
是 |
|
for-loop |
否 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
否 |
|
for-loop |
否 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
否 |
|
foreach |
是 |
是 |
下表顯示了 fused 實作的穩定性狀態
演算法 |
CPU |
CUDA |
MPS |
---|---|---|---|
不支援 |
不支援 |
不支援 |
|
不支援 |
不支援 |
不支援 |
|
beta |
不支援 |
不支援 |
|
beta |
穩定 |
beta |
|
beta |
穩定 |
beta |
|
不支援 |
不支援 |
不支援 |
|
不支援 |
不支援 |
不支援 |
|
不支援 |
不支援 |
不支援 |
|
不支援 |
不支援 |
不支援 |
|
不支援 |
不支援 |
不支援 |
|
不支援 |
不支援 |
不支援 |
|
不支援 |
不支援 |
不支援 |
|
不支援 |
不支援 |
不支援 |
|
beta |
beta |
beta |
如何調整學習率¶
torch.optim.lr_scheduler.LRScheduler
提供了幾種根據 epoch 數量調整學習率的方法。 torch.optim.lr_scheduler.ReduceLROnPlateau
允許根據一些驗證量測動態降低學習率。
學習率排程應該在最佳化器的更新之後應用;例如,您應該這樣編寫您的程式碼
範例
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)
for epoch in range(20):
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
scheduler.step()
大多數學習率排程器可以背靠背呼叫 (也稱為鏈式排程器)。 這樣做的結果是,每個排程器都按順序應用於前一個排程器獲得的學習率。
範例
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
for epoch in range(20):
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
scheduler1.step()
scheduler2.step()
在文件中的許多地方,我們將使用以下範本來引用排程器演算法。
>>> scheduler = ...
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
警告
在 PyTorch 1.1.0 之前,學習率排程器預計在最佳化器的更新之前被呼叫; 1.1.0 以 BC 中斷的方式改變了這種行為。 如果您在最佳化器的更新 (呼叫 optimizer.step()
) 之前使用學習率排程器 (呼叫 scheduler.step()
),這將跳過學習率排程的第一個值。 如果您在升級到 PyTorch 1.1.0 後無法重現結果,請檢查您是否在錯誤的時間呼叫 scheduler.step()
。
在最佳化期間調整學習率。 |
|
設定初始學習率。 |
|
將每個參數群組的學習率乘以指定函數中給定的因子。 |
|
每隔 step_size 個 epochs,將每個參數群組的學習率衰減 gamma 倍。 |
|
一旦 epoch 數達到其中一個里程碑,就將每個參數群組的學習率衰減 gamma 倍。 |
|
將每個參數群組的學習率乘以一個小的常數因子。 |
|
通過線性改變小的乘法因子來衰減每個參數組的學習率。 |
|
每個 epoch 將每個參數組的學習率衰減 gamma 倍。 |
|
使用給定的 total_iters 中的多項式函數衰減每個參數組的學習率。 |
|
使用餘弦退火排程設定每個參數群組的學習率。 |
|
將一系列學習率排程器串聯起來。 |
|
包含一個在優化過程中預期會依序呼叫的排程器列表。 |
|
當某個指標停止改善時,降低學習率。 |
|
根據循環學習率策略 (CLR) 設定每個參數群組的學習率。 |
|
根據 1cycle 學習率策略設定每個參數群組的學習率。 |
|
使用餘弦退火排程設定每個參數群組的學習率。 |
如何使用具名參數載入 optimizer 狀態字典¶
如果存在,函數 load_state_dict()
會儲存從載入的狀態字典中取得的可選 param_names
內容。但是,載入 optimizer 狀態的過程不會受到影響,因為參數的順序對於維持相容性非常重要 (以防順序不同)。若要使用從載入的狀態字典中載入的參數名稱,需要根據所需的行為實作一個自訂的 register_load_state_dict_pre_hook
。
這可能很有用,例如,當模型架構發生變化,但權重和 optimizer 狀態需要保持不變時。以下範例示範如何實作此自訂。
範例
class OneLayerModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(3, 4)
def forward(self, x):
return self.fc(x)
model = OneLayerModel()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)
假設 model
實作了一個專家模型 (MoE),並且我們想要複製它並恢復對兩個專家的訓練,兩者的初始化方式都與 fc
層相同。對於以下 model2
,我們建立兩個與 fc
相同的層,並透過將 model
的模型權重和 optimizer 狀態載入到 model2
的 fc1
和 fc2
中來恢復訓練(並相應地調整它們)
class TwoLayerModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3, 4)
self.fc2 = nn.Linear(3, 4)
def forward(self, x):
return (self.fc1(x) + self.fc2(x)) / 2
model2 = TwoLayerModel()
# adapt and load model weights..
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)
若要使用先前 optimizer 的狀態字典載入 optimizer2
的狀態字典,以便 fc1
和 fc2
都以 fc
optimizer 狀態的副本初始化 (以便從 fc
恢復每層的訓練),我們可以使用以下 hook
def adapt_state_dict_ids(optimizer, state_dict):
adapted_state_dict = deepcopy(optimizer.state_dict())
# Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
for k, v in state_dict['param_groups'][0].items():
if k not in ['params', 'param_names']:
adapted_state_dict['param_groups'][0][k] = v
lookup_dict = {
'fc1.weight': 'fc.weight',
'fc1.bias': 'fc.bias',
'fc2.weight': 'fc.weight',
'fc2.bias': 'fc.bias'
}
clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
for param_id, param_name in zip(
optimizer.state_dict()['param_groups'][0]['params'],
optimizer.state_dict()['param_groups'][0]['param_names']):
name_in_loaded = lookup_dict[param_name]
index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
# Copy the state of the corresponding parameter
if id_in_loaded in state_dict['state']:
adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])
return adapted_state_dict
optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict
這確保了在模型載入期間將使用針對 model2
層的正確狀態調整後的 state_dict。請注意,此程式碼專為此範例設計 (例如,假設單一參數群組),其他情況可能需要不同的調整。
以下範例示範了當模型結構發生變化時,如何處理載入的 state dict
中缺少的參數。Model_bypass
新增了一個新的 bypass
層,該層不存在於原始 Model1
中。為了恢復訓練,使用自訂的 adapt_state_dict_missing_param
hook 來調整 optimizer 的 state_dict
,確保現有參數被正確映射,而缺少的參數 (如 bypass 層) 保持不變 (如本範例中初始化)。儘管模型發生變化,此方法也能夠順利載入和恢復 optimizer 狀態。新的 bypass 層將從頭開始訓練
class Model1(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 5)
def forward(self, x):
return self.fc(x) + x
model = Model1()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)
class Model_bypass(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 5)
self.bypass = nn.Linear(5, 5, bias=False)
torch.nn.init.eye_(self.bypass.weight)
def forward(self, x):
return self.fc(x) + self.bypass(x)
model2 = Model_bypass()
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)
def adapt_state_dict_missing_param(optimizer, state_dict):
adapted_state_dict = deepcopy(optimizer.state_dict())
# Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
for k, v in state_dict['param_groups'][0].items():
if k not in ['params', 'param_names']:
adapted_state_dict['param_groups'][0][k] = v
lookup_dict = {
'fc.weight': 'fc.weight',
'fc.bias': 'fc.bias',
'bypass.weight': None,
}
clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
for param_id, param_name in zip(
optimizer.state_dict()['param_groups'][0]['params'],
optimizer.state_dict()['param_groups'][0]['param_names']):
name_in_loaded = lookup_dict[param_name]
if name_in_loaded in state_dict['param_groups'][0]['param_names']:
index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
# Copy the state of the corresponding parameter
if id_in_loaded in state_dict['state']:
adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])
return adapted_state_dict
optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict
作為第三個範例,這個 hook 可以用來根據參數的名稱載入狀態,而不是根據參數的順序載入狀態 (預設方法)。
def names_matching(optimizer, state_dict):
assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
adapted_state_dict = deepcopy(optimizer.state_dict())
for g_ind in range(len(state_dict['param_groups'])):
assert len(state_dict['param_groups'][g_ind]['params']) == len(
optimizer.state_dict()['param_groups'][g_ind]['params'])
for k, v in state_dict['param_groups'][g_ind].items():
if k not in ['params', 'param_names']:
adapted_state_dict['param_groups'][g_ind][k] = v
for param_id, param_name in zip(
optimizer.state_dict()['param_groups'][g_ind]['params'],
optimizer.state_dict()['param_groups'][g_ind]['param_names']):
index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
# Copy the state of the corresponding parameter
if id_in_loaded in state_dict['state']:
adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])
return adapted_state_dict
權重平均 (SWA 和 EMA)¶
torch.optim.swa_utils.AveragedModel
實作了隨機權重平均 (SWA) 和指數移動平均 (EMA),torch.optim.swa_utils.SWALR
實作了 SWA 學習率排程器,而 torch.optim.swa_utils.update_bn()
是一個公用程式函數,用於在訓練結束時更新 SWA/EMA 批次正規化統計資料。
SWA 已在 平均權重會產生更廣泛的 Optima 和更好的泛化效果 中提出。
EMA 是一種廣為人知的技術,透過減少所需的權重更新次數來縮短訓練時間。它是 Polyak 平均 的一種變體,但使用指數權重而不是跨迭代的相等權重。
建構平均模型¶
AveragedModel 類別用於計算 SWA 或 EMA 模型的權重。
您可以透過執行以下操作來建立 SWA 平均模型
>>> averaged_model = AveragedModel(model)
EMA 模型透過指定 multi_avg_fn
參數來建構,如下所示
>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))
Decay 是一個介於 0 和 1 之間的參數,用於控制平均參數衰減的速度。如果沒有提供給 torch.optim.swa_utils.get_ema_multi_avg_fn()
,則預設值為 0.999。 Decay 值應接近 1.0,因為較小的值可能會導致最佳化收斂問題。
torch.optim.swa_utils.get_ema_multi_avg_fn()
傳回一個將以下 EMA 方程式套用至權重的函數
其中 alpha 為 EMA 衰減率。
這裡的 model
可以是任意的 torch.nn.Module
物件。averaged_model
將會追蹤 model
參數的滾動平均值。 為了更新這些平均值,您應該在 optimizer.step() 之後使用 update_parameters()
函數
>>> averaged_model.update_parameters(model)
對於 SWA 和 EMA,此呼叫通常在優化器 step()
之後立即完成。在 SWA 的情況下,通常會在訓練開始時跳過一些步驟。
自訂平均策略¶
預設情況下,torch.optim.swa_utils.AveragedModel
計算您提供的參數的滾動等權平均值,但您也可以使用具有 avg_fn
或 multi_avg_fn
參數的自訂平均函數
avg_fn
允許定義一個作用於每個參數元組(平均參數,模型參數)的函數,並且應該返回新的平均參數。multi_avg_fn
允許定義更有效率的操作,同時作用於參數列表的元組(平均參數列表,模型參數列表),例如使用torch._foreach*
函數。 此函數必須就地更新平均參數。
在下面的範例中,ema_model
使用 avg_fn
參數計算指數移動平均
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>> 0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
在下面的範例中,ema_model
使用更有效率的 multi_avg_fn
參數計算指數移動平均
>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))
SWA 學習率排程¶
通常,在 SWA 中,學習率設定為高的常數值。SWALR
是一個學習率排程器,它將學習率退火到一個固定值,然後保持不變。 例如,以下程式碼建立一個排程器,該排程器在每個參數群組中於 5 個 epoch 內將學習率從其初始值線性退火到 0.05
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>> anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)
您也可以使用餘弦退火到一個固定值,而不是線性退火,方法是設定 anneal_strategy="cos"
。
注意批次正規化¶
update_bn()
是一個實用函數,允許在訓練結束時計算給定資料載入器 loader
上 SWA 模型的批次正規化統計量
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
update_bn()
將 swa_model
應用於資料載入器中的每個元素,並計算模型中每個批次正規化層的啟用統計量。
警告
update_bn()
假設資料載入器 loader
中的每個批次都是張量或張量列表,其中第一個元素是網路 swa_model
應用的張量。 如果您的資料載入器具有不同的結構,您可以通過對資料集的每個元素使用 swa_model
進行前向傳遞來更新 swa_model
的批次正規化統計量。
整合在一起:SWA¶
在下面的範例中,swa_model
是累積權重平均值的 SWA 模型。 我們總共訓練模型 300 個 epoch,並且我們切換到 SWA 學習率排程,並在 epoch 160 開始收集參數的 SWA 平均值
>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>> for input, target in loader:
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> if epoch > swa_start:
>>> swa_model.update_parameters(model)
>>> swa_scheduler.step()
>>> else:
>>> scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)
整合在一起:EMA¶
在下面的範例中,ema_model
是一個 EMA 模型,它使用 0.999 的衰減率來累計權重的指數衰減平均值。 我們總共訓練模型 300 個 epoch,並立即開始收集 EMA 平均值。
>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>> multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>> for input, target in loader:
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)
實現隨機權重平均 (Stochastic Weight Averaging, SWA) 和指數移動平均 (Exponential Moving Average, EMA) 的平均模型。 |
|
將每個參數群組中的學習率退火到一個固定值。 |
- torch.optim.swa_utils.get_ema_multi_avg_fn(decay=0.999)[source][source]¶
獲取在多個參數上應用指數移動平均 (EMA) 的函數。
- torch.optim.swa_utils.update_bn(loader, model, device=None)[source][source]¶
更新模型中的 BatchNorm running_mean 和 running_var 緩衝區。
它在 loader 中的數據上執行一次遍歷,以估計模型中 BatchNorm 層的激活統計信息。
- 參數
loader (torch.utils.data.DataLoader) – 用於計算激活統計信息的資料集載入器。每個數據批次應該是一個張量,或者是一個列表/元組,其第一個元素是一個包含數據的張量。
model (torch.nn.Module) – 我們尋求更新 BatchNorm 統計信息的模型。
device (torch.device, optional) – 如果設定,資料將在傳遞到
model
之前傳輸到device
。
範例
>>> loader, model = ... >>> torch.optim.swa_utils.update_bn(loader, model)
注意
update_bn 工具假設
loader
中的每個資料批次都是一個張量,或者是一個張量列表或元組;在後一種情況下,假設應該在與數據批次對應的列表或元組的第一個元素上調用model.forward()
。