快捷方式

TorchScript 語言參考

TorchScript 是 Python 的一個靜態類型子集,可以直接編寫(使用 @torch.jit.script 裝飾器),或者通過追蹤從 Python 代碼自動生成。當使用追蹤時,代碼會自動轉換為這個 Python 子集,方法是只記錄 tensors 上的實際運算符,並簡單地執行和丟棄其他周圍的 Python 代碼。

當直接使用 @torch.jit.script 裝飾器編寫 TorchScript 時,程式設計師必須只使用 TorchScript 中支援的 Python 子集。本節記錄了 TorchScript 中支援的內容,就像它是獨立語言的語言參考一樣。本參考中未提及的任何 Python 功能都不是 TorchScript 的一部分。有關可用的 PyTorch tensor 方法、模組和函式的完整參考,請參閱內建函式

作為 Python 的一個子集,任何有效的 TorchScript 函式也是一個有效的 Python 函式。這使得停用 TorchScript並使用標準 Python 工具(如 pdb)偵錯函式成為可能。反之則不然:有很多有效的 Python 程式不是有效的 TorchScript 程式。相反,TorchScript 特別關注 Python 中表示 PyTorch 中神經網路模型所需的功能。

類型

TorchScript 和完整 Python 語言之間最大的區別在於,TorchScript 僅支援一小部分表示神經網路模型所需的類型。特別是,TorchScript 支援

類型

描述

Tensor

任何 dtype、維度或後端的 PyTorch tensor

Tuple[T0, T1, ..., TN]

一個包含子類型 T0T1 等的 tuple(例如 Tuple[Tensor, Tensor]

bool

布林值

int

純量整數

float

純量浮點數

str

字串

List[T]

一個所有成員都是 T 類型的 list

Optional[T]

一個可以是 None 或 T 類型的值

Dict[K, V]

一個具有鍵類型 K 和值類型 V 的 dict。僅允許 strintfloat 作為鍵類型。

T

一個 TorchScript 類別

E

一個 TorchScript 列舉

NamedTuple[T0, T1, ...]

一個 collections.namedtuple tuple 類型

Union[T0, T1, ...]

子類型 T0T1 等之一。

與 Python 不同,TorchScript 函式中的每個變數都必須具有單一靜態類型。這使得優化 TorchScript 函式更容易。

範例(類型不匹配)

import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...

Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE...

不支援的類型構造

TorchScript 不支援 typing 模組的所有功能和類型。其中一些是更基本的事情,未來不太可能添加,而另一些則可能會添加,如果使用者有足夠的需求使其成為優先事項。

來自 typing 模組的這些類型和功能在 TorchScript 中不可用。

項目

描述

typing.Any

typing.Any 目前正在開發中,但尚未發布

typing.NoReturn

尚未實作

typing.Sequence

尚未實作

typing.Callable

尚未實作

typing.Literal

尚未實作

typing.ClassVar

尚未實作

typing.Final

這支援 模組屬性 類別屬性註解,但不支援函式

typing.AnyStr

TorchScript 不支援 bytes,因此未使用此類型

typing.overload

typing.overload 目前正在開發中,但尚未發布

類型別名

尚未實作

名義與結構子類型

名義類型正在開發中,但結構類型則不然

NewType

不太可能實作

泛型

不太可能實作

本文件中未明確列出的 typing 模組中的任何其他功能都不受支援。

預設類型

預設情況下,TorchScript 函式的所有參數都假定為 Tensor。要指定 TorchScript 函式的參數是另一種類型,可以使用 MyPy 樣式的類型註解,使用上面列出的類型。

import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

注意

也可以使用 typing 模組中的 Python 3 類型提示來註解類型。

import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

空列表假定為 List[Tensor],而空 dict 則為 Dict[str, Tensor]。要實例化其他類型的空列表或 dict,請使用Python 3 類型提示

範例(Python 3 的類型註解)

import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

可選類型改進

當在 if 語句的條件中進行與 None 的比較,或在 assert 中進行檢查時,TorchScript 將細化類型為 Optional[T] 的變數的類型。編譯器可以推斷多個使用 andornot 組合的 None 檢查。對於未明確編寫的 if 語句的 else 區塊,也會發生類型細化。

None 檢查必須位於 if 語句的條件中;將 None 檢查賦值給變數並在 if 語句的條件中使用它將不會細化檢查中變數的類型。只有區域變數會被細化,像是 self.x 這樣的屬性則不會,必須將其賦值給區域變數才能進行細化。

範例 (細化參數和區域變數的類型)

import torch
import torch.nn as nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super().__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

TorchScript 類別

警告

TorchScript 類別支援是實驗性的。目前,它最適合簡單的紀錄類型 (可以將其視為附加方法的 NamedTuple)。

如果 Python 類別使用 @torch.jit.script 進行註解,則可以在 TorchScript 中使用它們,就像您宣告 TorchScript 函數一樣。

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x

  def aug_add_x(self, inc):
    self.x += inc

此子集受到限制

  • 所有函數都必須是有效的 TorchScript 函數 (包括 __init__())。

  • 類別必須是新型類別,因為我們使用 __new__() 與 pybind11 構造它們。

  • TorchScript 類別是靜態類型。成員只能通過在 __init__() 方法中賦值給 self 來宣告。

    例如,在 __init__() 方法之外賦值給 self

    @torch.jit.script
    class Foo:
      def assign_x(self):
        self.x = torch.rand(2, 3)
    

    將導致

    RuntimeError:
    Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
    def assign_x(self):
      self.x = torch.rand(2, 3)
      ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
  • 除了方法定義外,類別主體中不允許任何表達式。

  • 不支援繼承或任何其他多型策略,除了從 object 繼承以指定新型類別。

定義類別後,它可以像任何其他 TorchScript 類型一樣在 TorchScript 和 Python 中互換使用。

# Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

TorchScript 列舉 (Enums)

Python 列舉可以在 TorchScript 中使用,而無需任何額外的註解或程式碼。

from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True

    return x == y

定義列舉後,它可以像任何其他 TorchScript 類型一樣在 TorchScript 和 Python 中互換使用。列舉值的類型必須是 intfloatstr。所有值必須是相同類型;不支援列舉值的異質類型。

具名元組 (Named Tuples)

collections.namedtuple 產生的類型可以在 TorchScript 中使用。

import torch
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

可迭代物件 (Iterables)

某些函數 (例如,zipenumerate) 只能對可迭代類型進行操作。TorchScript 中的可迭代類型包括 Tensor、列表、元組、字典、字串、torch.nn.ModuleListtorch.nn.ModuleDict

表達式

支援以下 Python 表達式。

字面量 (Literals)

True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

列表建構 (List Construction)

假設空列表的類型為 List[Tensor]。其他列表字面量的類型是從成員的類型推導出來的。有關更多詳細信息,請參閱預設類型

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

元組建構 (Tuple Construction)

(3, 4)
(3,)

字典建構 (Dict Construction)

假設空字典的類型為 Dict[str, Tensor]。其他字典字面量的類型是從成員的類型推導出來的。有關更多詳細信息,請參閱預設類型

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

變數 (Variables)

有關如何解析變數,請參閱變數解析

my_variable_name

算術運算符 (Arithmetic Operators)

a + b
a - b
a * b
a / b
a ^ b
a @ b

比較運算符 (Comparison Operators)

a == b
a != b
a < b
a > b
a <= b
a >= b

邏輯運算符 (Logical Operators)

a and b
a or b
not b

下標和切片 (Subscripts and Slicing)

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

函數呼叫 (Function Calls)

呼叫內建函數

torch.rand(3, dtype=torch.int)

呼叫其他腳本函數

import torch

@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)

方法呼叫 (Method Calls)

呼叫內建類型 (如 tensor) 的方法:x.mm(y)

在模組上,方法必須先編譯才能被呼叫。TorchScript 編譯器在編譯其他方法時,會遞迴編譯它看到的方法。預設情況下,編譯從 forward 方法開始。forward 呼叫的任何方法都會被編譯,以及這些方法呼叫的任何方法,依此類推。若要從 forward 以外的方法開始編譯,請使用 @torch.jit.export 裝飾器 (decorator)(forward 會隱式地被標記為 @torch.jit.export)。

直接呼叫子模組(例如,self.resnet(input))等效於呼叫其 forward 方法(例如,self.resnet.forward(input))。

import torch
import torch.nn as nn
import torchvision

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

    def helper(self, input):
        return self.resnet(input - self.means)

    def forward(self, input):
        return self.helper(input)

    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)

    def other_helper(self, input):
        return input + 10

# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

三元表達式 (Ternary Expressions)

x if x > y else y

類型轉換 (Casts)

float(ten)
int(3.5)
bool(ten)
str(2)``

存取模組參數 (Accessing Module Parameters)

self.my_parameter
self.my_submodule.my_parameter

陳述式 (Statements)

TorchScript 支援下列類型的陳述式:

簡單賦值 (Simple Assignments)

a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

模式匹配賦值 (Pattern Matching Assignments)

a, b = tuple_or_list
a, b, *c = a_tuple

多重賦值 (Multiple Assignments)

a = b, c = tup

If 陳述式 (If Statements)

if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

除了布林值之外,浮點數、整數和張量 (Tensors) 也可以在條件式中使用,並會隱式轉換為布林值。

While 迴圈 (While Loops)

a = 0
while a < 4:
    print(a)
    a += 1

帶有 range 的 For 迴圈 (For loops with range)

x = 0
for i in range(10):
    x *= i

遍歷 tuple 的 For 迴圈 (For loops over tuples)

這些會展開迴圈,為 tuple 的每個成員生成一個主體 (body)。對於每個成員,主體必須正確進行類型檢查 (type-check)。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

遍歷常數 nn.ModuleList 的 For 迴圈 (For loops over constant nn.ModuleList)

要在編譯方法中使用 nn.ModuleList,必須將該屬性的名稱添加到類型的 __constants__ 列表中,將其標記為常數。遍歷 nn.ModuleList 的 For 迴圈將在編譯時展開迴圈的主體,其中包含常數模組列表的每個成員。

class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    __constants__ = ['mods']

    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

Break 和 Continue

for i in range(5):
    if i == 1:
        continue
    if i == 3:
        break
    print(i)

Return

return a, b

變數解析 (Variable Resolution)

TorchScript 支援 Python 變數解析(即作用域)規則的子集。局部變數的行為與 Python 中相同,但有一個限制,即變數在函數的所有路徑上必須具有相同的類型。如果一個變數在 if 陳述式的不同分支上具有不同的類型,則在 if 陳述式結束後使用它會產生錯誤。

同樣地,如果一個變數僅在函數的某些路徑上被定義,則不允許使用該變數。

範例

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...

y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- HERE
    print(y)
and was used here:
    if x < 0:
        y = 4
    print(y)
          ~ <--- HERE...

非局部變數在函數定義時解析為 Python 值。然後,使用 使用 Python 值 中描述的規則將這些值轉換為 TorchScript 值。

使用 Python 值 (Use of Python Values)

為了使編寫 TorchScript 更方便,我們允許腳本程式碼引用周圍範圍內的 Python 值。例如,任何時候引用 torch,TorchScript 編譯器實際上是在宣告函數時將其解析為 torch Python 模組。這些 Python 值不是 TorchScript 的第一類部分。相反,它們在編譯時被簡化為 TorchScript 支援的基本類型。這取決於編譯時引用的 Python 值的動態類型。本節描述了在 TorchScript 中存取 Python 值時使用的規則。

函數 (Functions)

TorchScript 可以呼叫 Python 函數。在將模型逐步轉換為 TorchScript 時,此功能非常有用。可以將模型逐函數移動到 TorchScript,並保留對 Python 函數的呼叫。這樣,您可以逐步檢查模型的正確性。

torch.jit.is_scripting()[source][source]

此函數在編譯時返回 True,否則返回 False。這在使用 @unused 裝飾器時特別有用,可以在您的模型中保留尚未與 TorchScript 相容的程式碼。 .. testcode

import torch

@torch.jit.unused
def unsupported_linear_op(x):
    return x

def linear(x):
    if torch.jit.is_scripting():
        return torch.linear(x)
    else:
        return unsupported_linear_op(x)
返回類型

bool

torch.jit.is_tracing()[source][source]

返回一個布林值。

如果在追蹤過程中(如果在透過 torch.jit.trace 追蹤程式碼時呼叫函數),則返回 True,否則返回 False

在 Python 模組上查找屬性 (Attribute Lookup On Python Modules)

TorchScript 可以在模組上查找屬性。像 torch.add 這樣的 內建函數 就是透過這種方式存取的。這允許 TorchScript 呼叫在其他模組中定義的函數。

Python 定義的常數 (Python-defined Constants)

TorchScript 還提供了一種使用 Python 中定義的常數的方法。這些可用於將超參數硬編碼到函數中,或定義通用常數。有兩種方法可以指定應將 Python 值視為常數。

  1. 作為模組的屬性查找的值被假定為常數

import math
import torch

@torch.jit.script
def fn():
    return math.pi
  1. ScriptModule 的屬性可以使用 Final[T] 進行註釋,以將其標記為常數

import torch
import torch.nn as nn

class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]

    def __init__(self):
        super().__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

支援的常數 Python 類型包括:

  • int

  • float

  • bool

  • torch.device

  • torch.layout

  • torch.dtype

  • 包含支援類型的 tuple

  • torch.nn.ModuleList,可以在 TorchScript for 迴圈中使用

模組屬性 (Module Attributes)

可以使用 torch.nn.Parameter 封裝器和 register_buffer 將 tensors 指派給模組。如果類型可以被推斷,則編譯後指派給模組的其他值將被新增到編譯後的模組中。TorchScript 中可用的所有 類型 都可以用作模組屬性。Tensor 屬性在語義上與 buffers 相同。空列表和字典以及 None 值的類型無法被推斷,必須透過 PEP 526 風格 的類別註解來指定。如果類型無法被推斷並且沒有明確註解,它將不會作為屬性新增到產生的 ScriptModule 中。

範例

from typing import List, Dict

class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]

    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]

    def __init__(self, a_dict):
        super().__init__()
        self.words = []
        self.some_dict = a_dict

        # `int`s can be inferred
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))

文件

取得 PyTorch 的完整開發者文件

查看文件

教學課程

取得適合初學者和進階開發人員的深度教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源