捷徑

TorchScript

TorchScript 是一種從 PyTorch 程式碼建立可序列化和可優化模型的方法。任何 TorchScript 程式都可以從 Python 程序儲存,並在沒有 Python 依賴性的程序中載入。

我們提供工具來逐步將模型從純 Python 程式轉換為可以獨立於 Python 執行的 TorchScript 程式,例如在獨立的 C++ 程式中。這使得可以在 Python 中使用熟悉的工具在 PyTorch 中訓練模型,然後透過 TorchScript 將模型匯出到生產環境,在該環境中,Python 程式可能由於效能和多線程原因而不利。

如需 TorchScript 的簡要介紹,請參閱TorchScript 簡介教學。

如需將 PyTorch 模型轉換為 TorchScript 並在 C++ 中執行的端到端範例,請參閱在 C++ 中載入 PyTorch 模型教學。

建立 TorchScript 程式碼

script

編寫函數腳本。

trace

追蹤函數並返回一個可執行檔或 ScriptFunction,該函數將使用即時編譯進行優化。

script_if_tracing

在追蹤期間首次呼叫 fn 時編譯它。

trace_module

追蹤模組並返回一個可執行檔 ScriptModule,該模組將使用即時編譯進行優化。

fork

建立一個執行 func 的非同步任務,以及對該執行結果值的引用。

wait

強制完成 torch.jit.Future[T] 非同步任務,並返回任務的結果。

ScriptModule

C++ torch::jit::Module 的包裝器,具有方法、屬性和參數。

ScriptFunction

功能上等同於 ScriptModule,但表示單個函數,並且沒有任何屬性或參數。

freeze

凍結 ScriptModule,內聯子模組和屬性作為常量。

optimize_for_inference

執行一組優化傳遞,以優化模型的推理目的。

enable_onednn_fusion

根據參數 enabled 啟用或停用 onednn JIT 融合。

onednn_fusion_enabled

返回是否啟用 onednn JIT 融合。

set_fusion_strategy

設定融合期間可能發生的特殊化類型和數量。

strict_fusion

如果在推論中並非所有節點都已融合,或者在訓練中進行了符號微分,則會產生錯誤。

save

儲存此模組的離線版本,以便在單獨的程序中使用。

load

載入先前使用 torch.jit.save 儲存的 ScriptModuleScriptFunction

ignore

此裝飾器向編譯器指示應忽略函數或方法,並將其保留為 Python 函數。

unused

此裝飾器向編譯器指示應忽略函數或方法,並將其替換為引發異常。

interface

裝飾以註解不同類型的類別或模組。

isinstance

在 TorchScript 中提供容器類型細化。

Attribute

此方法是一個傳遞函數,返回 value,主要用於向 TorchScript 編譯器指示左側表達式是類型為 type 的類別實例屬性。

annotate

用於在 TorchScript 編譯器中指定 the_value 的類型。

混合追蹤和腳本編寫

在許多情況下,追蹤或腳本編寫都是將模型轉換為 TorchScript 的更簡單方法。 可以組合追蹤和腳本編寫,以滿足模型某部分的特定要求。

編寫腳本的函數可以呼叫追蹤的函數。 當您需要在簡單的前饋模型周圍使用控制流程時,這特別有用。 例如,序列到序列模型的波束搜索通常用腳本編寫,但可以呼叫使用追蹤生成的編碼器模組。

範例(在腳本中呼叫追蹤的函數)

import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)

追蹤的函數可以呼叫腳本函數。 當模型的一小部分需要一些控制流程時,即使模型的大部分只是前饋網絡,這也很有用。 在追蹤函數呼叫的腳本函數內部的控制流程會正確保留。

範例(在追蹤函數中呼叫腳本函數)

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

此組合也適用於 nn.Module,在這種情況下,它可用於使用追蹤生成一個子模組,該子模組可以從腳本模組的方法中呼叫。

範例(使用追蹤的模組)

import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

TorchScript 語言

TorchScript 是 Python 的靜態類型子集,因此許多 Python 功能直接適用於 TorchScript。 有關詳細信息,請參閱完整的TorchScript 語言參考

內建函數和模組

TorchScript 支援使用大多數 PyTorch 函數和許多 Python 內建函數。 有關支援函數的完整參考,請參閱TorchScript 內建函數

PyTorch 函數和模組

TorchScript 支援 PyTorch 提供的張量和神經網絡函數的子集。 Tensor 上的大多數方法以及 torch 命名空間中的函數、torch.nn.functional 中的所有函數以及 torch.nn 中的大多數模組都受到 TorchScript 的支援。

有關不支援的 PyTorch 函數和模組的列表,請參閱TorchScript 不支援的 PyTorch 結構

Python 函數和模組

TorchScript 支援許多 Python 的 內建函數math 模組也受支援(有關詳細信息,請參閱math 模組),但不支援任何其他 Python 模組(內建或第三方)。

Python 語言參考比較

有關支援的 Python 功能的完整列表,請參閱Python 語言參考覆蓋範圍

除錯

停用 JIT 以進行除錯

PYTORCH_JIT

設定環境變數 PYTORCH_JIT=0 將會停用所有 script 和 tracing 的註解。如果您的 TorchScript 模型中出現難以除錯的錯誤,您可以使用此標記強制所有內容都使用原生 Python 執行。由於使用此標記會停用 TorchScript (腳本和追蹤),您可以使用 pdb 等工具來偵錯模型程式碼。例如:

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

使用 pdb 偵錯此腳本可以正常運作,除了當我們調用 @torch.jit.script 函數時。我們可以全域停用 JIT,這樣我們就可以將 @torch.jit.script 函數作為一般的 Python 函數呼叫,而不是編譯它。如果上述腳本名為 disable_jit_example.py,我們可以這樣調用它:

$ PYTORCH_JIT=0 python disable_jit_example.py

然後我們就能夠像一般的 Python 函數一樣逐步進入 @torch.jit.script 函數。要為特定函數停用 TorchScript 編譯器,請參閱 @torch.jit.ignore

檢查程式碼

TorchScript 為所有 ScriptModule 實例提供了一個程式碼美化列印器。這個美化列印器將 script 方法的程式碼解釋為有效的 Python 語法。例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

具有單個 forward 方法的 ScriptModule 將具有一個 code 屬性,您可以使用它來檢查 ScriptModule 的程式碼。如果 ScriptModule 有多個方法,您需要在該方法本身而不是模組上存取 .code。我們可以通過存取 .foo.code 來檢查 ScriptModule 上名為 foo 的方法的程式碼。上面的例子產生以下輸出:

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

這是 TorchScript 編譯的 forward 方法的程式碼。您可以使用它來確保 TorchScript (追蹤或腳本) 正確捕獲了您的模型程式碼。

解讀圖

TorchScript 在比程式碼美化列印器更低的層次上也有一個表示,即 IR 圖的形式。

TorchScript 使用靜態單一賦值 (SSA) 中間表示 (IR) 來表示計算。這種格式的指令由 ATen (PyTorch 的 C++ 後端) 運算符和其他原始運算符組成,包括用於循環和條件的控制流運算符。舉例來說:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph 遵循 檢查程式碼 章節中關於 forward 方法查詢的相同規則。

上面的例子腳本產生以下圖:

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

以指令 %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 為例。

  • %rv.1 : Tensor 意味著我們將輸出分配給一個名為 rv.1 的(唯一)值,該值的類型為 Tensor,並且我們不知道它的具體形狀。

  • aten::zeros 是運算符 (相當於 torch.zeros),輸入列表 (%4, %6, %6, %10, %12) 指定應將範圍內的哪些值作為輸入傳遞。可以在 內建函數 中找到內建函數 (如 aten::zeros) 的 schema。

  • # test.py:9:10 是原始源文件中生成此指令的位置。在這種情況下,它是一個名為 test.py 的文件,位於第 9 行,字符 10。

請注意,運算符也可以具有相關聯的 blocks,即 prim::Loopprim::If 運算符。在圖的列印輸出中,這些運算符的格式被格式化為反映它們等效的原始碼形式,以方便進行除錯。

可以檢查圖表,以確認 ScriptModule 描述的計算是正確的,無論是以自動或手動方式,如下所述。

Tracer

追蹤邊緣情況

存在一些邊緣情況,其中給定 Python 函數/模組的追蹤不能代表底層程式碼。這些情況可能包括:

  • 追蹤依賴於輸入(例如張量形狀)的控制流

  • 追蹤 tensor 視圖的原地 (in-place) 操作 (例如,在賦值運算式的左側進行索引)

請注意,未來這些情況實際上可能可以追蹤。

自動追蹤檢查

自動捕捉追蹤中許多錯誤的一種方法是在 torch.jit.trace() API 上使用 check_inputscheck_inputs 接受一個輸入元組列表,這些元組將用於重新追蹤計算並驗證結果。例如:

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

給我們以下診斷資訊:

ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

此訊息表明,我們第一次追蹤它時與使用 check_inputs 追蹤它時,計算結果不同。 實際上,loop_in_traced_fn 主體中的迴圈取決於輸入 x 的形狀,因此當我們嘗試使用不同形狀的另一個 x 時,追蹤結果會有所不同。

在這種情況下,可以使用 torch.jit.script() 來捕獲這種依賴於資料的控制流程

def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())

for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))

產生:

graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

追蹤器警告

追蹤器會針對追蹤計算中的幾個有問題的模式產生警告。 例如,追蹤一個包含 Tensor 切片(視圖)上的原地賦值的函式:

def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

產生多個警告和一個僅返回輸入的圖形:

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

我們可以透過修改程式碼以不使用原地更新來修復此問題,而是使用 torch.cat 以異地 (out-of-place) 方式建立結果 tensor:

def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

常見問題

問:我想在 GPU 上訓練模型,並在 CPU 上進行推論。 最佳做法是什麼?

首先將您的模型從 GPU 轉換為 CPU,然後儲存它,如下所示:

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pt")

traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")

# ... later, when using the model:

if use_gpu:
  model = torch.jit.load("gpu.pt")
else:
  model = torch.jit.load("cpu.pt")

model(input)

建議這樣做,因為追蹤器可能會觀察到在特定裝置上建立 tensor,因此轉換已載入的模型可能會產生意想不到的影響。 在儲存模型*之前*轉換模型可確保追蹤器具有正確的裝置資訊。

問:我該如何在 ScriptModule 上儲存屬性?

假設我們有一個像這樣的模型:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x = 2

    def forward(self):
        return self.x

m = torch.jit.script(Model())

如果實例化 Model,則會導致編譯錯誤,因為編譯器不知道 x。 有 4 種方法可以告知編譯器關於 ScriptModule 上的屬性:

1. nn.Parameter - 包裹在 nn.Parameter 中的值將像在 nn.Module 中一樣運作

2. register_buffer - 包裹在 register_buffer 中的值將像在 nn.Module 中一樣運作。 這等效於 Tensor 類型的屬性(請參閱 4)。

3. 常數 - 將類別成員註解為 Final(或將其添加到類別定義層級上名為 __constants__ 的列表中)會將包含的名稱標記為常數。 常數直接儲存在模型的程式碼中。 有關詳細資訊,請參閱builtin-constants

4. 屬性 - 可以將作為supported type 的值新增為可變屬性。 大多數類型都可以推斷,但有些可能需要指定,有關詳細資訊,請參閱module attributes

問:我想追蹤模組的方法,但我一直收到此錯誤

RuntimeError: 無法 需要 梯度 Tensor 插入 常數。 考慮 將其 作為 參數 輸入, 分離 梯度

此錯誤通常表示您正在追蹤的方法使用模組的參數,並且您正在傳遞模組的方法而不是模組實例(例如,my_module_instance.forwardmy_module_instance)。

  • 使用模組的方法調用 trace 會將模組參數(可能需要梯度)捕獲為**常數**。

  • 另一方面,使用模組的實例(例如,my_module)調用 trace 會建立一個新模組,並正確地將參數複製到新模組中,因此如果需要,它們可以累積梯度。

若要追蹤模組上的特定方法,請參閱 torch.jit.trace_module

已知問題

如果您將 Sequential 與 TorchScript 搭配使用,則某些 Sequential 子模組的輸入可能會被錯誤地推斷為 Tensor,即使它們以其他方式進行註解。 標準的解決方案是子類化 nn.Sequential 並使用正確輸入類型重新宣告 forward

附錄

遷移到 PyTorch 1.2 遞迴腳本 API

本節詳細介紹了 PyTorch 1.2 中 TorchScript 的變更。 如果您不熟悉 TorchScript,則可以跳過本節。 PyTorch 1.2 的 TorchScript API 有兩個主要變更。

1. torch.jit.script 現在將嘗試遞迴編譯它遇到的函數、方法和類別。 呼叫 torch.jit.script 後,編譯是“選擇退出 (opt-out)”,而不是“選擇加入 (opt-in)”。

2. 現在建議使用 torch.jit.script(nn_module_instance) 來建立 ScriptModule,而不是繼承自 torch.jit.ScriptModule。這些變更結合起來,提供了一個更簡單、更容易使用的 API,可將您的 nn.Module 轉換為 ScriptModule,以便在非 Python 環境中進行最佳化和執行。

新的使用方式如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

my_model = Model()
my_scripted_model = torch.jit.script(my_model)
  • 模組的 forward 預設會被編譯。從 forward 呼叫的方法會按照它們在 forward 中被使用的順序延遲編譯。

  • 若要編譯 forward 以外,且沒有從 forward 呼叫的方法,請加入 @torch.jit.export

  • 若要停止編譯器編譯方法,請加入 @torch.jit.ignore@torch.jit.unused@ignore 會將

  • 方法保留為對 Python 的呼叫,而 @unused 則會將其替換為例外。@ignored 無法匯出;@unused 則可以。

  • 大多數屬性類型都可以推斷出來,因此不需要 torch.jit.Attribute。對於空的容器類型,請使用 PEP 526 樣式的類別註釋來標註它們的類型。

  • 可以使用 Final 類別註釋來標記常數,而無需將成員名稱新增到 __constants__ 中。

  • 可以使用 Python 3 類型提示來代替 torch.jit.annotate

由於這些變更,以下項目被視為已棄用,不應出現在新程式碼中
  • @torch.jit.script_method 裝飾器

  • 繼承自 torch.jit.ScriptModule 的類別

  • torch.jit.Attribute 包裝類別

  • __constants__ 陣列

  • torch.jit.annotate 函數

模組

警告

@torch.jit.ignore 註解的行為在 PyTorch 1.2 中發生了變更。在 PyTorch 1.2 之前,@ignore 裝飾器用於使函數或方法可從匯出的程式碼中呼叫。若要恢復此功能,請使用 @torch.jit.unused()@torch.jit.ignore 現在等同於 @torch.jit.ignore(drop=False)。有關詳細資訊,請參閱 @torch.jit.ignore@torch.jit.unused

當傳遞給 torch.jit.script 函數時,torch.nn.Module 的資料會複製到 ScriptModule,並且 TorchScript 編譯器會編譯該模組。模組的 forward 預設會被編譯。從 forward 呼叫的方法會按照它們在 forward 中被使用的順序延遲編譯,以及任何 @torch.jit.export 方法。

torch.jit.export(fn)[來源][來源]

此裝飾器指示 nn.Module 上的方法用作 ScriptModule 的進入點,並且應該被編譯。

forward 隱含地假定為一個進入點,因此不需要此裝飾器。從 forward 呼叫的函數和方法會在編譯器看到它們時被編譯,因此它們也不需要此裝飾器。

範例(在方法上使用 @torch.jit.export

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

函數

函數沒有太多變化,如有需要,可以使用 @torch.jit.ignoretorch.jit.unused 進行裝飾。

# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2

# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4

# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2

TorchScript 類別

警告

TorchScript 類別支援是實驗性的。目前,它最適合簡單的記錄類型(想想附加了方法的 NamedTuple)。

在使用者定義的 TorchScript Class 中的所有內容預設都會匯出,如果需要,可以使用 @torch.jit.ignore 來裝飾函式。

屬性 (Attributes)

TorchScript 編譯器需要知道 模組屬性 (module attributes) 的類型。大多數類型可以從成員的值推斷出來。空的 list 和 dict 無法推斷出它們的類型,因此必須使用 PEP 526 風格的類別註解來標註其類型。如果類型無法推斷,且沒有明確註解,它將不會作為屬性新增到最終的 ScriptModule 中。

舊的 API

from typing import Dict
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)

m = MyModule()

新的 API

from typing import Dict

class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]

    def __init__(self):
        super().__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}

        # The attribute type here is inferred to be `int`
        self.my_int = 20

    def forward(self):
        pass

m = torch.jit.script(MyModule())

常數 (Constants)

Final 類型建構子可以用於將成員標記為 常數 (constant)。如果成員沒有標記為常數,它們將作為屬性複製到最終的 ScriptModule 中。如果已知該值是固定的,則使用 Final 可以開啟優化的機會,並提供額外的類型安全。

舊的 API

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass
m = MyModule()

新的 API

from typing import Final

class MyModule(torch.nn.Module):

    my_constant: Final[int]

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass

m = torch.jit.script(MyModule())

變數 (Variables)

容器預設的類型為 Tensor 且為 non-optional (更多資訊請參考 預設類型 (Default Types))。先前,torch.jit.annotate 用於告知 TorchScript 編譯器應使用的類型。現在支援 Python 3 風格的類型提示。

import torch
from typing import Dict, Optional

@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b

融合後端 (Fusion Backends)

有幾個融合後端可用於優化 TorchScript 的執行。CPU 上的預設融合器是 NNC,它可以為 CPU 和 GPU 執行融合。GPU 上的預設融合器是 NVFuser,它支援更廣泛的運算子,並且已經展示了具有改進吞吐量的生成核心。有關使用和除錯的更多詳細資訊,請參閱 NVFuser 文件

文件

存取 PyTorch 的完整開發人員文件

檢視文件 (View Docs)

教學 (Tutorials)

取得初學者和進階開發人員的深入教學

檢視教學 (View Tutorials)

資源 (Resources)

尋找開發資源並獲得您的問題解答

檢視資源 (View Resources)