快捷方式

學習基礎 || 快速入門 || 張量 || 資料集 & 資料載入器 || 轉換 || 建置模型 || 自動微分 || 最佳化 || 儲存 & 載入模型

儲存和載入模型

建立於:2021 年 2 月 9 日 | 最後更新:2024 年 10 月 15 日 | 最後驗證:2024 年 11 月 05 日

在本節中,我們將研究如何透過儲存、載入和執行模型預測來保存模型狀態。

import torch
import torchvision.models as models

儲存和載入模型權重

PyTorch 模型將學習到的參數儲存在一個內部狀態字典中,稱為 state_dict。這些可以透過 torch.save 方法來保存。

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/vgg16-397923af.pth

  0%|          | 0.00/528M [00:00<?, ?B/s]
  4%|3         | 20.6M/528M [00:00<00:02, 216MB/s]
  8%|7         | 41.9M/528M [00:00<00:02, 219MB/s]
 12%|#1        | 63.1M/528M [00:00<00:02, 221MB/s]
 16%|#5        | 84.4M/528M [00:00<00:02, 222MB/s]
 20%|##        | 106M/528M [00:00<00:01, 222MB/s]
 24%|##4       | 127M/528M [00:00<00:01, 222MB/s]
 28%|##8       | 148M/528M [00:00<00:01, 223MB/s]
 32%|###2      | 170M/528M [00:00<00:01, 223MB/s]
 36%|###6      | 191M/528M [00:00<00:01, 223MB/s]
 40%|####      | 212M/528M [00:01<00:01, 223MB/s]
 44%|####4     | 234M/528M [00:01<00:01, 223MB/s]
 48%|####8     | 255M/528M [00:01<00:01, 223MB/s]
 52%|#####2    | 276M/528M [00:01<00:01, 223MB/s]
 56%|#####6    | 298M/528M [00:01<00:01, 223MB/s]
 60%|######    | 319M/528M [00:01<00:00, 223MB/s]
 65%|######4   | 341M/528M [00:01<00:00, 223MB/s]
 69%|######8   | 362M/528M [00:01<00:00, 223MB/s]
 73%|#######2  | 383M/528M [00:01<00:00, 223MB/s]
 77%|#######6  | 405M/528M [00:01<00:00, 223MB/s]
 81%|########  | 426M/528M [00:02<00:00, 223MB/s]
 85%|########4 | 448M/528M [00:02<00:00, 223MB/s]
 89%|########8 | 469M/528M [00:02<00:00, 223MB/s]
 93%|#########2| 490M/528M [00:02<00:00, 223MB/s]
 97%|#########6| 511M/528M [00:02<00:00, 223MB/s]
100%|##########| 528M/528M [00:02<00:00, 223MB/s]

要載入模型權重,您需要先建立同一個模型的實例,然後使用 load_state_dict() 方法載入參數。

在下面的程式碼中,我們設定 weights_only=True 以限制在反序列化期間執行的函數,僅限於載入權重所需的函數。 使用 weights_only=True 被認為是載入權重時的最佳實踐。

model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意

請務必在推論前呼叫 model.eval() 方法,將 dropout 和批次正規化層設定為評估模式。 如果不這樣做,將會產生不一致的推論結果。

儲存和載入具有形狀的模型

在載入模型權重時,我們需要先實例化模型類別,因為該類別定義了網路的結構。 我們可能希望將這個類別的結構與模型一起儲存,在這種情況下,我們可以將 model (而不是 model.state_dict())傳遞給儲存函數。

torch.save(model, 'model.pth')

然後,我們可以如下所示載入模型。

儲存和載入 torch.nn.Modules 中所述,儲存 state_dict 被認為是最佳實踐。 但是,下面我們使用 weights_only=False,因為這涉及載入模型,這是 torch.save 的舊版用例。

model = torch.load('model.pth', weights_only=False),

注意

此方法在序列化模型時使用 Python pickle 模組,因此它依賴於載入模型時可用的實際類別定義。

文件

獲取 PyTorch 的完整開發者文件

查看文件

教學課程

獲取適合初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源