torch.optim.Optimizer.state_dict¶
- Optimizer.state_dict()[來源][來源]¶
以
dict
形式傳回優化器的狀態。它包含兩個條目
state
:一個保存目前優化狀態的 Dict。其內容不同優化器類別之間有所差異,但仍有一些共同特性。 例如,狀態是針對每個參數儲存的,而參數本身則**不會**儲存。
state
是一個字典,將參數 ID 映射到一個字典,其中包含與每個參數對應的狀態。
param_groups
:一個包含所有參數群組的列表,其中每個參數群組都是一個字典。 每個參數群組都包含特定於優化器的元數據,例如學習率和權重衰減,以及群組中參數的參數 ID 列表。 如果參數群組是使用
named_parameters()
初始化的,則名稱內容也將保存在狀態字典中。
注意:參數 ID 看起來可能像索引,但它們只是將狀態與 param_group 關聯的 ID。 從 state_dict 載入時,優化器將壓縮 param_group 的
params
(整數 ID)和優化器的param_groups
(實際的nn.Parameter
),以便在沒有額外驗證的情況下匹配狀態。一個返回的 state dict 可能看起來像這樣
{ 'state': { 0: {'momentum_buffer': tensor(...), ...}, 1: {'momentum_buffer': tensor(...), ...}, 2: {'momentum_buffer': tensor(...), ...}, 3: {'momentum_buffer': tensor(...), ...} }, 'param_groups': [ { 'lr': 0.01, 'weight_decay': 0, ... 'params': [0] 'param_names' ['param0'] (optional) }, { 'lr': 0.001, 'weight_decay': 0.5, ... 'params': [1, 2, 3] 'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional) } ] }