注意
點擊這裡以下載完整的範例程式碼
學習基礎 || 快速入門 || 張量 || 資料集 & 資料載入器 || 轉換 || 建立模型 || 自動微分 || 最佳化 || 儲存 & 載入模型
最佳化模型參數¶
建立於:2021 年 2 月 9 日 | 最後更新:2024 年 1 月 31 日 | 最後驗證:2024 年 11 月 5 日
現在我們有了模型和資料,是時候透過在我們的資料上最佳化模型的參數來訓練、驗證和測試我們的模型了。訓練模型是一個迭代過程;在每次迭代中,模型都會對輸出做出猜測,計算猜測中的誤差 (損失),收集誤差相對於其參數的導數 (正如我們在前一節中所見),並使用梯度下降最佳化這些參數。如需此過程的更詳細演練,請查看此關於來自 3Blue1Brown 的反向傳播的影片。
先決條件程式碼¶
我們從先前的章節載入關於資料集 & 資料載入器和建立模型的程式碼。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
0%| | 0.00/26.4M [00:00<?, ?B/s]
0%| | 65.5k/26.4M [00:00<01:13, 361kB/s]
1%| | 229k/26.4M [00:00<00:38, 676kB/s]
3%|3 | 885k/26.4M [00:00<00:12, 2.04MB/s]
10%|9 | 2.56M/26.4M [00:00<00:04, 5.95MB/s]
25%|##5 | 6.62M/26.4M [00:00<00:01, 12.9MB/s]
39%|###9 | 10.4M/26.4M [00:00<00:00, 19.0MB/s]
60%|#####9 | 15.8M/26.4M [00:01<00:00, 23.7MB/s]
71%|#######1 | 18.8M/26.4M [00:01<00:00, 25.1MB/s]
94%|#########3| 24.8M/26.4M [00:01<00:00, 28.1MB/s]
100%|##########| 26.4M/26.4M [00:01<00:00, 19.2MB/s]
0%| | 0.00/29.5k [00:00<?, ?B/s]
100%|##########| 29.5k/29.5k [00:00<00:00, 326kB/s]
0%| | 0.00/4.42M [00:00<?, ?B/s]
1%|1 | 65.5k/4.42M [00:00<00:12, 360kB/s]
5%|5 | 229k/4.42M [00:00<00:06, 680kB/s]
20%|## | 885k/4.42M [00:00<00:01, 2.53MB/s]
44%|####3 | 1.93M/4.42M [00:00<00:00, 4.10MB/s]
100%|##########| 4.42M/4.42M [00:00<00:00, 6.08MB/s]
0%| | 0.00/5.15k [00:00<?, ?B/s]
100%|##########| 5.15k/5.15k [00:00<00:00, 34.4MB/s]
超參數¶
超參數是可以調整的參數,可讓您控制模型最佳化過程。不同的超參數值會影響模型訓練和收斂速度 (閱讀更多關於超參數調整)
- 我們定義以下用於訓練的超參數
Epoch 數 - 疊代資料集的次數
批次大小 - 在更新參數之前,透過網路傳播的資料樣本數
學習率 - 在每個批次/epoch 更新模型參數的程度。較小的值會產生較慢的學習速度,而較大的值可能會導致訓練期間出現不可預測的行為。
learning_rate = 1e-3
batch_size = 64
epochs = 5
最佳化迴圈¶
一旦我們設定了超參數,我們就可以使用最佳化迴圈來訓練和最佳化我們的模型。最佳化迴圈的每次疊代都稱為一個 epoch。
- 每個 epoch 由兩個主要部分組成
訓練迴圈 - 疊代訓練資料集並嘗試收斂到最佳參數。
驗證/測試迴圈 - 疊代測試資料集以檢查模型效能是否正在提高。
讓我們簡要熟悉訓練迴圈中使用的一些概念。跳到前面以查看最佳化迴圈的完整實作。
損失函數¶
當呈現一些訓練資料時,我們未訓練的網路可能無法給出正確的答案。損失函數衡量獲得的結果與目標值的差異程度,它是我們希望在訓練期間最小化的損失函數。為了計算損失,我們使用給定資料樣本的輸入進行預測,並將其與真實資料標籤值進行比較。
常見的損失函數包括nn.MSELoss (均方誤差),用於迴歸任務,以及nn.NLLLoss (負對數似然),用於分類。nn.CrossEntropyLoss 結合了 nn.LogSoftmax
和 nn.NLLLoss
。
我們將模型的輸出 logits 傳遞給 nn.CrossEntropyLoss
,它將正規化 logits 並計算預測誤差。
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
最佳化器¶
最佳化 (Optimization) 是調整模型參數以減少每個訓練步驟中模型誤差的過程。 最佳化演算法 (Optimization algorithms) 定義了如何執行此過程(在此範例中,我們使用隨機梯度下降)。 所有最佳化邏輯都封裝在 optimizer
物件中。 在這裡,我們使用 SGD 最佳化器;此外,PyTorch 中還有許多 不同的最佳化器 可用,例如 ADAM 和 RMSProp,它們更適合不同類型的模型和資料。
我們透過註冊需要訓練的模型參數並傳入學習率超參數來初始化最佳化器。
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
- 在訓練迴圈中,最佳化分為三個步驟:
呼叫
optimizer.zero_grad()
以重置模型參數的梯度。 預設情況下,梯度會累加;為了防止重複計算,我們在每次迭代時都會明確地將其歸零。透過呼叫
loss.backward()
來反向傳播預測損失。 PyTorch 會將損失相對於每個參數的梯度存入。一旦有了梯度,我們就呼叫
optimizer.step()
來透過反向傳遞中收集的梯度調整參數。
完整實作¶
我們定義 train_loop
,它循環執行我們的最佳化程式碼,以及 test_loop
,它評估模型針對我們的測試資料的效能。
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
# Set the model to training mode - important for batch normalization and dropout layers
# Unnecessary in this situation but added for best practices
model.train()
for batch, (X, y) in enumerate(dataloader):
# Compute prediction and loss
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
loss.backward()
optimizer.step()
optimizer.zero_grad()
if batch % 100 == 0:
loss, current = loss.item(), batch * batch_size + len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
# Set the model to evaluation mode - important for batch normalization and dropout layers
# Unnecessary in this situation but added for best practices
model.eval()
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
# Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
# also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
我們初始化損失函數和最佳化器,並將其傳遞給 train_loop
和 test_loop
。 隨時增加 epoch 的數量來追蹤模型效能的改進。
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer)
test_loop(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.298730 [ 64/60000]
loss: 2.289123 [ 6464/60000]
loss: 2.273286 [12864/60000]
loss: 2.269406 [19264/60000]
loss: 2.249603 [25664/60000]
loss: 2.229407 [32064/60000]
loss: 2.227368 [38464/60000]
loss: 2.204261 [44864/60000]
loss: 2.206193 [51264/60000]
loss: 2.166651 [57664/60000]
Test Error:
Accuracy: 50.9%, Avg loss: 2.166725
Epoch 2
-------------------------------
loss: 2.176750 [ 64/60000]
loss: 2.169595 [ 6464/60000]
loss: 2.117500 [12864/60000]
loss: 2.129272 [19264/60000]
loss: 2.079674 [25664/60000]
loss: 2.032928 [32064/60000]
loss: 2.050115 [38464/60000]
loss: 1.985236 [44864/60000]
loss: 1.987887 [51264/60000]
loss: 1.907162 [57664/60000]
Test Error:
Accuracy: 55.9%, Avg loss: 1.915486
Epoch 3
-------------------------------
loss: 1.951612 [ 64/60000]
loss: 1.928685 [ 6464/60000]
loss: 1.815709 [12864/60000]
loss: 1.841552 [19264/60000]
loss: 1.732467 [25664/60000]
loss: 1.692914 [32064/60000]
loss: 1.701714 [38464/60000]
loss: 1.610632 [44864/60000]
loss: 1.632870 [51264/60000]
loss: 1.514263 [57664/60000]
Test Error:
Accuracy: 58.8%, Avg loss: 1.541525
Epoch 4
-------------------------------
loss: 1.616448 [ 64/60000]
loss: 1.582892 [ 6464/60000]
loss: 1.427595 [12864/60000]
loss: 1.487950 [19264/60000]
loss: 1.359332 [25664/60000]
loss: 1.364817 [32064/60000]
loss: 1.371491 [38464/60000]
loss: 1.298706 [44864/60000]
loss: 1.336201 [51264/60000]
loss: 1.232145 [57664/60000]
Test Error:
Accuracy: 62.2%, Avg loss: 1.260237
Epoch 5
-------------------------------
loss: 1.345538 [ 64/60000]
loss: 1.327798 [ 6464/60000]
loss: 1.153802 [12864/60000]
loss: 1.254829 [19264/60000]
loss: 1.117322 [25664/60000]
loss: 1.153248 [32064/60000]
loss: 1.171765 [38464/60000]
loss: 1.110263 [44864/60000]
loss: 1.154467 [51264/60000]
loss: 1.070921 [57664/60000]
Test Error:
Accuracy: 64.1%, Avg loss: 1.089831
Epoch 6
-------------------------------
loss: 1.166889 [ 64/60000]
loss: 1.170514 [ 6464/60000]
loss: 0.979435 [12864/60000]
loss: 1.113774 [19264/60000]
loss: 0.973411 [25664/60000]
loss: 1.015192 [32064/60000]
loss: 1.051113 [38464/60000]
loss: 0.993591 [44864/60000]
loss: 1.039709 [51264/60000]
loss: 0.971077 [57664/60000]
Test Error:
Accuracy: 65.8%, Avg loss: 0.982440
Epoch 7
-------------------------------
loss: 1.045165 [ 64/60000]
loss: 1.070583 [ 6464/60000]
loss: 0.862304 [12864/60000]
loss: 1.022265 [19264/60000]
loss: 0.885213 [25664/60000]
loss: 0.919528 [32064/60000]
loss: 0.972762 [38464/60000]
loss: 0.918728 [44864/60000]
loss: 0.961629 [51264/60000]
loss: 0.904379 [57664/60000]
Test Error:
Accuracy: 66.9%, Avg loss: 0.910167
Epoch 8
-------------------------------
loss: 0.956964 [ 64/60000]
loss: 1.002171 [ 6464/60000]
loss: 0.779057 [12864/60000]
loss: 0.958409 [19264/60000]
loss: 0.827240 [25664/60000]
loss: 0.850262 [32064/60000]
loss: 0.917320 [38464/60000]
loss: 0.868384 [44864/60000]
loss: 0.905506 [51264/60000]
loss: 0.856353 [57664/60000]
Test Error:
Accuracy: 68.3%, Avg loss: 0.858248
Epoch 9
-------------------------------
loss: 0.889765 [ 64/60000]
loss: 0.951220 [ 6464/60000]
loss: 0.717035 [12864/60000]
loss: 0.911042 [19264/60000]
loss: 0.786085 [25664/60000]
loss: 0.798370 [32064/60000]
loss: 0.874939 [38464/60000]
loss: 0.832796 [44864/60000]
loss: 0.863254 [51264/60000]
loss: 0.819742 [57664/60000]
Test Error:
Accuracy: 69.5%, Avg loss: 0.818780
Epoch 10
-------------------------------
loss: 0.836395 [ 64/60000]
loss: 0.910220 [ 6464/60000]
loss: 0.668506 [12864/60000]
loss: 0.874338 [19264/60000]
loss: 0.754805 [25664/60000]
loss: 0.758453 [32064/60000]
loss: 0.840451 [38464/60000]
loss: 0.806153 [44864/60000]
loss: 0.830360 [51264/60000]
loss: 0.790281 [57664/60000]
Test Error:
Accuracy: 71.0%, Avg loss: 0.787271
Done!