PyTorch: Control Flow + Weight Sharing
Created On: Mar 24, 2017 | Last Updated: Dec 28, 2021 | Last Verified: Nov 05, 2024
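To showcase the power of PyTorch dynamic graphs, we will implement a very strange model: a polynomial of third to fifth order that, on each forward pass, picks a random degree between 3 and 5 and uses that many orders, reusing the same weight to compute the fourth- and fifth-order terms.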
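As a quick check on the degree selection, the sketch below enumerates what the loop bound `range(4, random.randint(4, 6))` used in the model can produce. The helper name `highest_order` is hypothetical and exists only for this illustration; it is not part of the tutorial's model.

```python
import random
from collections import Counter


def highest_order(upper):
    """Degree reached when range(4, upper) adds terms on top of the cubic base."""
    extra_terms = list(range(4, upper))
    return extra_terms[-1] if extra_terms else 3


# random.randint(4, 6) is inclusive on both ends, so upper is 4, 5, or 6.
degrees = {upper: highest_order(upper) for upper in (4, 5, 6)}
print(degrees)  # {4: 3, 5: 4, 6: 5}

# Sample the bound many times to see how often each degree is chosen.
random.seed(0)
counts = Counter(highest_order(random.randint(4, 6)) for _ in range(3000))
print(dict(counts))
```

An upper bound of 4 gives an empty range, so the model stays cubic; 5 adds the fourth-order term; 6 adds both the fourth- and fifth-order terms.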
import random
import torch
import math


class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate five parameters and assign them as members.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose to add no extra
        term, the 4th-order term, or both the 4th- and 5th-order terms, and
        reuse the e parameter to compute the contribution of these orders.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same parameter many
        times when defining a computational graph.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        # randint is inclusive, so the loop runs for zero, one, or two extra orders.
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules.
        """
        return (
            f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 '
            f'+ {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'
        )


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = DynamicNet()

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(30000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')
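
To illustrate why reusing the same parameter several times in one graph is safe, here is a minimal sketch showing that autograd sums the gradient contributions from every use of a shared tensor. The scalar values of `x` and `e` are hypothetical and chosen only to make the arithmetic easy to follow; this is not part of the training script.

```python
import torch

# Hypothetical scalar input and shared weight, chosen only for illustration.
x = torch.tensor(2.0)
e = torch.tensor(0.5, requires_grad=True)

# Reuse e in both the 4th- and 5th-order terms, as forward() may do.
y = e * x ** 4 + e * x ** 5
y.backward()

# dy/de = x^4 + x^5 = 16 + 32 = 48: both uses contribute to one gradient.
print(e.grad.item())  # 48.0
```

Because both contributions land in the single `e.grad`, the optimizer updates the shared weight once per step regardless of how many terms used it in that forward pass.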