捷徑

torch.jit.script

torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)[原始碼][原始碼]

將函式 Script 化。

Script 化函式或 nn.Module 會檢測原始碼,使用 TorchScript 編譯器將其編譯為 TorchScript 程式碼,並傳回 ScriptModuleScriptFunction。TorchScript 本身是 Python 語言的子集,因此並非 Python 中的所有功能都適用,但我們提供了足夠的功能來計算張量並執行控制相關操作。如需完整指南,請參閱 TorchScript 語言參考

Script 化字典或列表會將其中的資料複製到 TorchScript 實例中,然後可以透過參照在 Python 和 TorchScript 之間傳遞,而無需複製額外負擔。

torch.jit.script 可以用作模組、函式、字典和列表的函式

並用作 TorchScript 類別 和函式的裝飾器 @torch.jit.script

參數
  • obj (Callable, class, or nn.Module) – 要編譯的 nn.Module、函式、類別類型、字典或列表。

  • example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]) – 提供範例輸入來註解函式或 nn.Module 的引數。

傳回

如果 objnn.Module,則 script 會傳回 ScriptModule 物件。傳回的 ScriptModule 將具有與原始 nn.Module 相同的子模組和參數集。如果 obj 是獨立函式,則會傳回 ScriptFunction。如果 objdict,則 script 會傳回 torch._C.ScriptDict 的實例。如果 objlist,則 script 會傳回 torch._C.ScriptList 的實例。

Script 化函式

@torch.jit.script 裝飾器將透過編譯函式的主體來建構 ScriptFunction

範例(Script 化函式)

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
**使用 example_inputs Script 化函式

可以使用範例輸入來註解函式引數。

範例(在 Script 化之前註解函式)

import torch

def test_sum(a, b):
    return a + b

# Annotate the arguments to be int
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

print(type(scripted_fn))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(scripted_fn.code)

# Call the function using the TorchScript interpreter
scripted_fn(20, 100)
Script 化 nn.Module

預設情況下,Script 化 nn.Module 會編譯 forward 方法,並遞迴編譯 forward 呼叫的任何方法、子模組和函式。如果 nn.Module 僅使用 TorchScript 中支援的功能,則無需變更原始模組程式碼。script 將建構 ScriptModule,該模組具有原始模組的屬性、參數和方法的副本。

範例(使用參數 Script 化簡單模組)

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

範例(使用追蹤的子模組 Script 化模組)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

若要編譯 forward 以外的方法(並遞迴編譯它呼叫的任何內容),請將 @torch.jit.export 裝飾器新增至該方法。若要選擇不編譯,請使用 @torch.jit.ignore@torch.jit.unused

範例(模組中匯出和忽略的方法)

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

範例(使用 example_inputs 註解 nn.Module 的 forward)

import torch
import torch.nn as nn
from typing import NamedTuple

class MyModule(NamedTuple):
result: List[int]

class TestNNModule(torch.nn.Module):
    def forward(self, a) -> MyModule:
        result = MyModule(result=a)
        return result

pdt_model = TestNNModule()

# Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

# Run the scripted_model with actual inputs
print(scripted_model([20]))

文件

存取 PyTorch 的全面開發人員文件

檢視文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得問題的解答

檢視資源