建立 TorchScript 模組¶
TorchScript 是一種從 PyTorch 程式碼建立可序列化和可最佳化模型的方法。PyTorch 有關於如何執行此操作的詳細文件 https://pytorch.dev.org.tw/tutorials/beginner/Intro_to_TorchScript_tutorial.html,但簡而言之,以下是關鍵背景資訊和流程
PyTorch 程式基於 Module
,可用於組成更高階的模組。Modules
包含一個建構子來設定模組、參數和子模組,以及一個 forward 函數,用於描述在調用模組時如何使用參數和子模組。
例如,我們可以像這樣定義一個 LeNet 模組
1import torch.nn as nn
2import torch.nn.functional as F
3
4
5class LeNetFeatExtractor(nn.Module):
6 def __init__(self):
7 super(LeNetFeatExtractor, self).__init__()
8 self.conv1 = nn.Conv2d(1, 6, 3)
9 self.conv2 = nn.Conv2d(6, 16, 3)
10
11 def forward(self, x):
12 x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
13 x = F.max_pool2d(F.relu(self.conv2(x)), 2)
14 return x
15
16
17class LeNetClassifier(nn.Module):
18 def __init__(self):
19 super(LeNetClassifier, self).__init__()
20 self.fc1 = nn.Linear(16 * 6 * 6, 120)
21 self.fc2 = nn.Linear(120, 84)
22 self.fc3 = nn.Linear(84, 10)
23
24 def forward(self, x):
25 x = torch.flatten(x, 1)
26 x = F.relu(self.fc1(x))
27 x = F.relu(self.fc2(x))
28 x = self.fc3(x)
29 return x
30
31
32class LeNet(nn.Module):
33 def __init__(self):
34 super(LeNet, self).__init__()
35 self.feat = LeNetFeatExtractor()
36 self.classifier = LeNetClassifier()
37
38 def forward(self, x):
39 x = self.feat(x)
40 x = self.classifier(x)
41 return x
.
顯然,您可能想要將如此簡單的模型整合到單一模組中,但我們可以在此處看到 PyTorch 的可組合性
從這裡開始,有兩種從 PyTorch Python 程式碼轉換為 TorchScript 程式碼的路徑:追蹤和腳本編寫。
追蹤會追蹤模組被呼叫時的執行路徑,並記錄發生的事情。若要追蹤 LeNet 模組的實例,我們可以呼叫 torch.jit.trace
並提供範例輸入。
import torch
model = LeNet()
input_data = torch.empty([1, 1, 32, 32])
traced_model = torch.jit.trace(model, input_data)
腳本編寫實際上會使用編譯器檢查您的程式碼,並產生等效的 TorchScript 程式。不同之處在於,由於追蹤會追蹤模組的執行,因此它無法擷取控制流程等。透過從 Python 程式碼開始,編譯器可以包含這些組件。我們可以透過呼叫 torch.jit.script
在我們的 LeNet 模組上執行腳本編寫編譯器
import torch
model = LeNet()
script_model = torch.jit.script(model)
有許多原因可以選擇其中一種路徑,PyTorch 文件中有關於如何選擇的資訊。從 Torch-TensorRT 的角度來看,追蹤模組有更好的支援(即,您的模組更可能編譯),因為它不包含完整程式語言的所有複雜性,儘管兩種路徑都支援。
在腳本編寫或追蹤您的模組之後,您會獲得一個 TorchScript 模組。這包含用於執行模組的程式碼和參數,這些程式碼和參數儲存在 Torch-TensorRT 可以使用的中間表示中。
以下是 LeNet 追蹤模組 IR 的外觀
graph(%self.1 : __torch__.___torch_mangle_10.LeNet,
%input.1 : Float(1, 1, 32, 32)):
%129 : __torch__.___torch_mangle_9.LeNetClassifier = prim::GetAttr[name="classifier"](%self.1)
%119 : __torch__.___torch_mangle_5.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self.1)
%137 : Tensor = prim::CallMethod[name="forward"](%119, %input.1)
%138 : Tensor = prim::CallMethod[name="forward"](%129, %137)
return (%138)
以及 LeNet 腳本編寫模組 IR
graph(%self : __torch__.LeNet,
%x.1 : Tensor):
%2 : __torch__.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self)
%x.3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # x.py:38:12
%5 : __torch__.LeNetClassifier = prim::GetAttr[name="classifier"](%self)
%x.5 : Tensor = prim::CallMethod[name="forward"](%5, %x.3) # x.py:39:12
return (%x.5)
您可以看到 IR 保留了我們在 Python 程式碼中的模組結構。
在 Python 中使用 TorchScript¶
TorchScript 模組的執行方式與您執行一般 PyTorch 模組的方式相同。您可以使用 forward
方法執行前向傳遞,或僅呼叫模組 torch_script_module(in_tensor)
。JIT 編譯器將即時編譯和最佳化模組,然後傳回結果。
將 TorchScript 模組儲存到磁碟¶
對於追蹤或腳本編寫模組,您可以使用以下命令將模組儲存到磁碟
import torch
model = LeNet()
script_model = torch.jit.script(model)
script_model.save("lenet_scripted.ts")