捷徑

TorchScript 簡介

建立於:2019 年 8 月 9 日 | 最後更新:2024 年 12 月 2 日 | 最後驗證:2024 年 11 月 5 日

作者:James Reed (jamesreed@fb.com), Michael Suo (suo@fb.com), rev2

警告

TorchScript 已停止積極開發。

本教學是 TorchScript 的簡介,TorchScript 是 PyTorch 模型(nn.Module 的子類)的中間表示形式,然後可以在高效能環境(例如 C++)中運行。

在本教學中,我們將涵蓋

  1. 在 PyTorch 中撰寫模型的基本知識,包括

  • 模組

  • 定義 forward 函數

  • 將模組組合成模組層次結構

  1. 將 PyTorch 模組轉換為 TorchScript(我們的高效能部署運行時)的特定方法

  • 追蹤現有模組

  • 使用腳本直接編譯模組

  • 如何組合這兩種方法

  • 儲存和載入 TorchScript 模組

我們希望您在完成本教學後,繼續學習 後續教學,該教學將引導您完成從 C++ 實際呼叫 TorchScript 模型的範例。

import torch  # This is all you need to use both PyTorch and TorchScript!
print(torch.__version__)
torch.manual_seed(191009)  # set the seed for reproducibility
2.6.0+cu124

<torch._C.Generator object at 0x7fd64938a3d0>

PyTorch 模型撰寫的基本知識

讓我們從定義一個簡單的 Module 開始。Module 是 PyTorch 中組合的基本單位。它包含

  1. 一個建構函式,用於準備模組以供調用

  2. 一組 Parameters 和子 Modules。這些由建構函式初始化,並且可以在模組調用期間使用。

  3. 一個 forward 函數。這是模組被調用時運行的程式碼。

讓我們檢查一個小範例

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))
(tensor([[0.8219, 0.8990, 0.6670, 0.8277],
        [0.5176, 0.4017, 0.8545, 0.7336],
        [0.6013, 0.6992, 0.2618, 0.6668]]), tensor([[0.8219, 0.8990, 0.6670, 0.8277],
        [0.5176, 0.4017, 0.8545, 0.7336],
        [0.6013, 0.6992, 0.2618, 0.6668]]))

所以我們

  1. 建立了一個繼承 torch.nn.Module 的類別。

  2. 定義了一個建構函式。建構函式沒有做太多事情,只是呼叫 super 的建構函式。

  3. 定義了一個 forward 函數,它接受兩個輸入並返回兩個輸出。forward 函數的實際內容並不重要,但它有點像假的 RNN cell – 也就是說 – 它是一個在迴圈上應用的函數。

我們實例化了模組,並建立了 xh,它們只是 3x4 的隨機值矩陣。然後我們使用 my_cell(x, h) 調用了 cell。這反過來呼叫我們的 forward 函數。

讓我們做一些更有趣的事情

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))
MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8573,  0.6190,  0.5774,  0.7869],
        [ 0.3326,  0.0530,  0.0702,  0.8114],
        [ 0.7818, -0.0506,  0.4039,  0.7967]], grad_fn=<TanhBackward0>), tensor([[ 0.8573,  0.6190,  0.5774,  0.7869],
        [ 0.3326,  0.0530,  0.0702,  0.8114],
        [ 0.7818, -0.0506,  0.4039,  0.7967]], grad_fn=<TanhBackward0>))

我們重新定義了我們的模組 MyCell,但這次我們新增了一個 self.linear 屬性,並且我們在 forward 函數中調用 self.linear

這裡到底發生了什麼?torch.nn.Linear 是 PyTorch 標準庫中的一個 Module。就像 MyCell 一樣,可以使用呼叫語法來調用它。我們正在建立 Module 的層次結構。

Module 上執行 print 將會提供 Module 的子類層次結構的可視化表示。在我們的範例中,我們可以看到我們的 Linear 子類及其參數。

透過以這種方式組合 Module,我們可以簡潔且可讀地撰寫具有可重複使用元件的模型。

您可能已經注意到輸出上的 grad_fn。這是 PyTorch 的自動微分方法(稱為 autograd)的詳細資訊。簡而言之,這個系統允許我們計算透過可能複雜的程式的導數。該設計允許在模型撰寫方面具有大量的彈性。

現在讓我們檢視一下這種彈性

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))
MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8346,  0.5931,  0.2097,  0.8232],
        [ 0.2340, -0.1254,  0.2679,  0.8064],
        [ 0.6231,  0.1494, -0.3110,  0.7865]], grad_fn=<TanhBackward0>), tensor([[ 0.8346,  0.5931,  0.2097,  0.8232],
        [ 0.2340, -0.1254,  0.2679,  0.8064],
        [ 0.6231,  0.1494, -0.3110,  0.7865]], grad_fn=<TanhBackward0>))

我們再次重新定義了 MyCell 類別,但這次我們定義了 MyDecisionGate。這個模組使用了控制流。控制流包含迴圈和 if 語句等。

許多框架採用給定完整程式表示式後,計算符號導數的方法。 然而,在 PyTorch 中,我們使用梯度帶 (gradient tape)。 我們記錄操作發生的過程,並反向重播它們來計算導數。 透過這種方式,框架不必顯式地為語言中的所有結構定義導數。

How autograd works

Autograd 的運作方式

TorchScript 基礎知識

現在讓我們來看一個執行中的範例,看看如何應用 TorchScript。

簡而言之,即使 PyTorch 具有彈性和動態特性,TorchScript 也能提供工具來捕捉模型的定義。 我們先從檢視稱為 追蹤 (tracing) 的內容開始。

追蹤 Modules

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)

(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))

我們稍微回溯了一下,採用了 MyCell 類別的第二個版本。 和之前一樣,我們實例化了它,但這次,我們呼叫了 torch.jit.trace,傳入了 Module,並且傳入了網路可能會看到的範例輸入

這到底做了什麼? 它呼叫了 Module,記錄了執行 Module 時發生的操作,並建立了 torch.jit.ScriptModule 的實例 (其中 TracedModule 是一個實例)。

TorchScript 將其定義記錄在一個中繼表示式 (Intermediate Representation, IR) 中,在深度學習中通常稱為圖 (graph)。 我們可以使用 .graph 屬性來檢視該圖。

graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:191:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:191:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:191:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)

但是,這是一個非常底層的表示式,並且圖中包含的大部分資訊對最終使用者沒有用處。 相反地,我們可以使用 .code 屬性來提供程式碼的 Python 語法解釋。

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)

那麼,我們為什麼要這麼做? 有幾個原因:

  1. TorchScript 程式碼可以在它自己的直譯器中呼叫,這基本上是一個受限的 Python 直譯器。 這個直譯器不會取得全域直譯器鎖 (Global Interpreter Lock),因此可以在同一個實例上同時處理許多請求。

  2. 這種格式讓我們可以將整個模型儲存到磁碟,並將其載入到另一個環境中,例如用 Python 以外的語言編寫的伺服器中。

  3. TorchScript 為我們提供了一種表示式,我們可以在程式碼上進行編譯器最佳化,以提供更有效率的執行。

  4. TorchScript 允許我們與許多後端/裝置執行時間進行介面連接,這些執行時間需要比單個運算符更廣泛的程式檢視。

我們可以發現,呼叫 traced_cell 產生的結果與 Python 模組相同。

print(my_cell(x, h))
print(traced_cell(x, h))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))

使用腳本 (Scripting) 來轉換模組

我們使用模組的第二個版本,而不是包含控制流子模組的版本,是有原因的。 讓我們現在檢視一下。

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)
/var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:263: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

def forward(self,
    argument_1: Tensor) -> NoneType:
  return None

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (linear).forward(x, )
  _1 = (dg).forward(_0, )
  _2 = torch.tanh(torch.add(_0, h))
  return (_2, _2)

查看 .code 輸出,我們可以發現 if-else 分支無處可尋! 為什麼? 追蹤完全按照我們所說的那樣:執行程式碼,記錄發生的操作,並建構一個完全按照該操作執行的 ScriptModule。 不幸的是,像是控制流之類的東西被消除了。

我們如何才能在 TorchScript 中忠實地表示這個模組? 我們提供了一個腳本編譯器,它可以直接分析您的 Python 原始程式碼,並將其轉換為 TorchScript。 讓我們使用腳本編譯器來轉換 MyDecisionGate

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)
def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)

萬歲! 我們現在已經在 TorchScript 中忠實地捕捉了我們程式的行為。 現在讓我們嘗試執行該程式。

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))
(tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
        [ 0.5228,  0.7122,  0.6985, -0.0656],
        [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>), tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
        [ 0.5228,  0.7122,  0.6985, -0.0656],
        [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>))

混合使用腳本和追蹤

在某些情況下,需要使用追蹤而不是腳本 (例如,一個模組有許多架構決策是基於常數 Python 值做出的,而我們不希望這些值出現在 TorchScript 中)。 在這種情況下,腳本可以與追蹤組合:torch.jit.script 將嵌入追蹤模組的程式碼,而追蹤將嵌入腳本模組的程式碼。

第一個案例的範例

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)

以及第二個案例的範例

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

透過這種方式,可以在需要的情況下使用腳本和追蹤,並將它們一起使用。

儲存和載入模型

我們提供 API 來將 TorchScript 模組以封存格式儲存到磁碟和從磁碟載入。 這種格式包含程式碼、參數、屬性和偵錯資訊,這意味著該封存是模型的獨立表示式,可以載入到完全獨立的程序中。 讓我們儲存和載入我們封裝的 RNN 模組。

traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)
RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

如您所見,序列化保留了模組層次結構和我們一直在檢視的程式碼。 該模型也可以載入,例如,載入到 C++ 中以進行無 Python 執行。

延伸閱讀

我們已經完成了本教學課程! 如需更複雜的示範,請查看 NeurIPS 示範,瞭解如何使用 TorchScript 轉換機器翻譯模型:https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ

腳本的總執行時間: (0 分鐘 0.217 秒)

由 Sphinx-Gallery 產生之圖片集

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得適用於初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源