快捷方式

PyTorch: optim

建立於:2020 年 12 月 03 日 | 最後更新:2020 年 12 月 03 日 | 最後驗證:未驗證

一個三階多項式,通過最小化歐幾里德距離的平方,訓練以預測 \(y=\sin(x)\)\(-\pi\)\(pi\)

此實作使用 PyTorch 的 nn 套件來建立網路。

我們不手動更新模型的權重,而是使用 optim 套件來定義一個將為我們更新權重的 Optimizer。 optim 套件定義了許多常用於深度學習的優化演算法,包括 SGD+momentum、RMSProp、Adam 等。

import torch
import math


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Prepare the input tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use RMSprop; the optim package contains many other
# optimization algorithms. The first argument to the RMSprop constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
for t in range(2000):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(xx)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers( i.e, not overwritten) whenever .backward()
    # is called. Checkout docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()


linear_layer = model[0]
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')

腳本總執行時間:(0 分鐘 0.000 秒)

由 Sphinx-Gallery 產生的圖庫

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

獲取初學者和進階開發者的深入教學課程

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源