快捷鍵

TorchScript 中的模型凍結

建立於:2020 年 7 月 28 日 | 最後更新:2024 年 12 月 02 日 | 最後驗證:2024 年 11 月 05 日

警告

TorchScript 已不再積極開發。

在本教學中,我們介紹 TorchScript 中模型凍結的語法。凍結是將 Pytorch 模組參數和屬性值內聯到 TorchScript 內部表示中的過程。參數和屬性值被視為最終值,並且不能在產生的凍結模組中修改。

基本語法

可以使用下面的 API 調用模型凍結

torch.jit.freeze(mod : ScriptModule, names : str[]) -> ScriptModule

請注意,輸入模組可以是腳本編寫或追蹤的結果。請參閱https://pytorch.dev.org.tw/tutorials/beginner/Intro_to_TorchScript_tutorial.html

接下來,我們將使用一個範例演示凍結的工作原理

import torch, time

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout2d(0.25)
        self.dropout2 = torch.nn.Dropout2d(0.5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = torch.nn.functional.log_softmax(x, dim=1)
        return output

    @torch.jit.export
    def version(self):
        return 1.0

net = torch.jit.script(Net())
fnet = torch.jit.freeze(net)

print(net.conv1.weight.size())
print(net.conv1.bias)

try:
    print(fnet.conv1.bias)
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'conv1'
except RuntimeError:
    print("field 'conv1' is inlined. It does not exist in 'fnet'")

try:
    fnet.version()
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'version'
except RuntimeError:
    print("method 'version' is not deleted in fnet. Only 'forward' is preserved")

fnet2 = torch.jit.freeze(net, ["version"])

print(fnet2.version())

B=1
warmup = 1
iter = 1000
input = torch.rand(B, 1,28, 28)

start = time.time()
for i in range(warmup):
    net(input)
end = time.time()
print("Scripted - Warm up time: {0:7.4f}".format(end-start), flush=True)

start = time.time()
for i in range(warmup):
    fnet(input)
end = time.time()
print("Frozen   - Warm up time: {0:7.4f}".format(end-start), flush=True)

start = time.time()
for i in range(iter):
    input = torch.rand(B, 1,28, 28)
    net(input)
end = time.time()
print("Scripted - Inference: {0:5.2f}".format(end-start), flush=True)

start = time.time()
for i in range(iter):
    input = torch.rand(B, 1,28, 28)
    fnet2(input)
end = time.time()
print("Frozen    - Inference time: {0:5.2f}".format(end-start), flush =True)

在我的機器上,我測量了時間

  • 腳本編寫 - 預熱時間:0.0107

  • 凍結 - 預熱時間:0.0048

  • 腳本編寫 - 推理:1.35

  • 凍結 - 推理時間:1.17

在我們的範例中,預熱時間測量了前兩次執行。凍結模型比腳本編寫模型快 50%。在一些更複雜的模型上,我們觀察到預熱時間的加速甚至更高。凍結實現了這種加速,因為它正在做 TorchScript 在啟動前幾次執行時必須做的一些工作。

推理時間測量模型預熱後的推理執行時間。雖然我們觀察到執行時間的顯著變化,但凍結模型通常比腳本編寫模型快約 15%。當輸入更大時,我們觀察到較小的加速,因為執行由張量運算主導。

結論

在本教學中,我們學習了模型凍結。凍結是一種有用的技術,可以優化模型以進行推理,並且還可以顯著減少 TorchScript 預熱時間。

腳本的總執行時間: (0 分鐘 0.000 秒)

由 Sphinx-Gallery 產生的圖庫


評價本教學

© Copyright 2024, PyTorch。

使用 Sphinx 構建,主題由 theme 提供,並由 Read the Docs 提供。

文件

存取 PyTorch 的全面開發者文件

檢視文件

教學課程

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源