TorchScript 語言參考¶
TorchScript 是 Python 的一個靜態類型子集,可以直接編寫(使用 @torch.jit.script
裝飾器),或者通過追蹤從 Python 代碼自動生成。當使用追蹤時,代碼會自動轉換為這個 Python 子集,方法是只記錄 tensors 上的實際運算符,並簡單地執行和丟棄其他周圍的 Python 代碼。
當直接使用 @torch.jit.script
裝飾器編寫 TorchScript 時,程式設計師必須只使用 TorchScript 中支援的 Python 子集。本節記錄了 TorchScript 中支援的內容,就像它是獨立語言的語言參考一樣。本參考中未提及的任何 Python 功能都不是 TorchScript 的一部分。有關可用的 PyTorch tensor 方法、模組和函式的完整參考,請參閱內建函式。
作為 Python 的一個子集,任何有效的 TorchScript 函式也是一個有效的 Python 函式。這使得停用 TorchScript並使用標準 Python 工具(如 pdb
)偵錯函式成為可能。反之則不然:有很多有效的 Python 程式不是有效的 TorchScript 程式。相反,TorchScript 特別關注 Python 中表示 PyTorch 中神經網路模型所需的功能。
類型¶
TorchScript 和完整 Python 語言之間最大的區別在於,TorchScript 僅支援一小部分表示神經網路模型所需的類型。特別是,TorchScript 支援
類型 |
描述 |
---|---|
|
任何 dtype、維度或後端的 PyTorch tensor |
|
一個包含子類型 |
|
布林值 |
|
純量整數 |
|
純量浮點數 |
|
字串 |
|
一個所有成員都是 |
|
一個可以是 None 或 |
|
一個具有鍵類型 |
|
|
|
|
|
一個 |
|
子類型 |
與 Python 不同,TorchScript 函式中的每個變數都必須具有單一靜態類型。這使得優化 TorchScript 函式更容易。
範例(類型不匹配)
import torch
@torch.jit.script
def an_error(x):
if x:
r = torch.rand(1)
else:
r = 4
return r
Traceback (most recent call last):
...
RuntimeError: ...
Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
if x:
~~~~~
r = torch.rand(1)
~~~~~~~~~~~~~~~~~
else:
~~~~~
r = 4
~~~~~ <--- HERE
return r
and was used here:
else:
r = 4
return r
~ <--- HERE...
不支援的類型構造¶
TorchScript 不支援 typing
模組的所有功能和類型。其中一些是更基本的事情,未來不太可能添加,而另一些則可能會添加,如果使用者有足夠的需求使其成為優先事項。
來自 typing
模組的這些類型和功能在 TorchScript 中不可用。
項目 |
描述 |
---|---|
|
|
尚未實作 |
|
尚未實作 |
|
尚未實作 |
|
尚未實作 |
|
尚未實作 |
|
這支援 模組屬性 類別屬性註解,但不支援函式 |
|
TorchScript 不支援 |
|
|
|
類型別名 |
尚未實作 |
名義與結構子類型 |
名義類型正在開發中,但結構類型則不然 |
NewType |
不太可能實作 |
泛型 |
不太可能實作 |
本文件中未明確列出的 typing
模組中的任何其他功能都不受支援。
預設類型¶
預設情況下,TorchScript 函式的所有參數都假定為 Tensor。要指定 TorchScript 函式的參數是另一種類型,可以使用 MyPy 樣式的類型註解,使用上面列出的類型。
import torch
@torch.jit.script
def foo(x, tup):
# type: (int, Tuple[Tensor, Tensor]) -> Tensor
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
注意
也可以使用 typing
模組中的 Python 3 類型提示來註解類型。
import torch
from typing import Tuple
@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
空列表假定為 List[Tensor]
,而空 dict 則為 Dict[str, Tensor]
。要實例化其他類型的空列表或 dict,請使用Python 3 類型提示。
範例(Python 3 的類型註解)
import torch
import torch.nn as nn
from typing import Dict, List, Tuple
class EmptyDataStructures(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
# This annotates the list to be a `List[Tuple[int, float]]`
my_list: List[Tuple[int, float]] = []
for i in range(10):
my_list.append((i, x.item()))
my_dict: Dict[str, int] = {}
return my_list, my_dict
x = torch.jit.script(EmptyDataStructures())
可選類型改進¶
當在 if 語句的條件中進行與 None
的比較,或在 assert
中進行檢查時,TorchScript 將細化類型為 Optional[T]
的變數的類型。編譯器可以推斷多個使用 and
、or
和 not
組合的 None
檢查。對於未明確編寫的 if 語句的 else 區塊,也會發生類型細化。
None
檢查必須位於 if 語句的條件中;將 None
檢查賦值給變數並在 if 語句的條件中使用它將不會細化檢查中變數的類型。只有區域變數會被細化,像是 self.x
這樣的屬性則不會,必須將其賦值給區域變數才能進行細化。
範例 (細化參數和區域變數的類型)
import torch
import torch.nn as nn
from typing import Optional
class M(nn.Module):
z: Optional[int]
def __init__(self, z):
super().__init__()
# If `z` is None, its type cannot be inferred, so it must
# be specified (above)
self.z = z
def forward(self, x, y, z):
# type: (Optional[int], Optional[int], Optional[int]) -> int
if x is None:
x = 1
x = x + 1
# Refinement for an attribute by assigning it to a local
z = self.z
if y is not None and z is not None:
x = y + z
# Refinement via an `assert`
assert z is not None
x += z
return x
module = torch.jit.script(M(2))
module = torch.jit.script(M(None))
TorchScript 類別¶
警告
TorchScript 類別支援是實驗性的。目前,它最適合簡單的紀錄類型 (可以將其視為附加方法的 NamedTuple
)。
如果 Python 類別使用 @torch.jit.script
進行註解,則可以在 TorchScript 中使用它們,就像您宣告 TorchScript 函數一樣。
@torch.jit.script
class Foo:
def __init__(self, x, y):
self.x = x
def aug_add_x(self, inc):
self.x += inc
此子集受到限制
所有函數都必須是有效的 TorchScript 函數 (包括
__init__()
)。類別必須是新型類別,因為我們使用
__new__()
與 pybind11 構造它們。TorchScript 類別是靜態類型。成員只能通過在
__init__()
方法中賦值給 self 來宣告。例如,在
__init__()
方法之外賦值給self
@torch.jit.script class Foo: def assign_x(self): self.x = torch.rand(2, 3)
將導致
RuntimeError: Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: def assign_x(self): self.x = torch.rand(2, 3) ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
除了方法定義外,類別主體中不允許任何表達式。
不支援繼承或任何其他多型策略,除了從
object
繼承以指定新型類別。
定義類別後,它可以像任何其他 TorchScript 類型一樣在 TorchScript 和 Python 中互換使用。
# Declare a TorchScript class
@torch.jit.script
class Pair:
def __init__(self, first, second):
self.first = first
self.second = second
@torch.jit.script
def sum_pair(p):
# type: (Pair) -> Tensor
return p.first + p.second
p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))
TorchScript 列舉 (Enums)¶
Python 列舉可以在 TorchScript 中使用,而無需任何額外的註解或程式碼。
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
定義列舉後,它可以像任何其他 TorchScript 類型一樣在 TorchScript 和 Python 中互換使用。列舉值的類型必須是 int
、float
或 str
。所有值必須是相同類型;不支援列舉值的異質類型。
具名元組 (Named Tuples)¶
由 collections.namedtuple
產生的類型可以在 TorchScript 中使用。
import torch
import collections
Point = collections.namedtuple('Point', ['x', 'y'])
@torch.jit.script
def total(point):
# type: (Point) -> Tensor
return point.x + point.y
p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))
可迭代物件 (Iterables)¶
某些函數 (例如,zip
和 enumerate
) 只能對可迭代類型進行操作。TorchScript 中的可迭代類型包括 Tensor
、列表、元組、字典、字串、torch.nn.ModuleList
和 torch.nn.ModuleDict
。
表達式¶
支援以下 Python 表達式。
字面量 (Literals)¶
True
False
None
'string literals'
"string literals"
3 # interpreted as int
3.4 # interpreted as a float
列表建構 (List Construction)¶
假設空列表的類型為 List[Tensor]
。其他列表字面量的類型是從成員的類型推導出來的。有關更多詳細信息,請參閱預設類型。
[3, 4]
[]
[torch.rand(3), torch.rand(4)]
元組建構 (Tuple Construction)¶
(3, 4)
(3,)
算術運算符 (Arithmetic Operators)¶
a + b
a - b
a * b
a / b
a ^ b
a @ b
比較運算符 (Comparison Operators)¶
a == b
a != b
a < b
a > b
a <= b
a >= b
邏輯運算符 (Logical Operators)¶
a and b
a or b
not b
下標和切片 (Subscripts and Slicing)¶
t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]
函數呼叫 (Function Calls)¶
呼叫內建函數
torch.rand(3, dtype=torch.int)
呼叫其他腳本函數
import torch
@torch.jit.script
def foo(x):
return x + 1
@torch.jit.script
def bar(x):
return foo(x)
方法呼叫 (Method Calls)¶
呼叫內建類型 (如 tensor) 的方法:x.mm(y)
在模組上,方法必須先編譯才能被呼叫。TorchScript 編譯器在編譯其他方法時,會遞迴編譯它看到的方法。預設情況下,編譯從 forward
方法開始。forward
呼叫的任何方法都會被編譯,以及這些方法呼叫的任何方法,依此類推。若要從 forward
以外的方法開始編譯,請使用 @torch.jit.export
裝飾器 (decorator)(forward
會隱式地被標記為 @torch.jit.export
)。
直接呼叫子模組(例如,self.resnet(input)
)等效於呼叫其 forward
方法(例如,self.resnet.forward(input)
)。
import torch
import torch.nn as nn
import torchvision
class MyModule(nn.Module):
def __init__(self):
super().__init__()
means = torch.tensor([103.939, 116.779, 123.68])
self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
resnet = torchvision.models.resnet18()
self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))
def helper(self, input):
return self.resnet(input - self.means)
def forward(self, input):
return self.helper(input)
# Since nothing in the model calls `top_level_method`, the compiler
# must be explicitly told to compile this method
@torch.jit.export
def top_level_method(self, input):
return self.other_helper(input)
def other_helper(self, input):
return input + 10
# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())
三元表達式 (Ternary Expressions)¶
x if x > y else y
類型轉換 (Casts)¶
float(ten)
int(3.5)
bool(ten)
str(2)``
存取模組參數 (Accessing Module Parameters)¶
self.my_parameter
self.my_submodule.my_parameter
陳述式 (Statements)¶
TorchScript 支援下列類型的陳述式:
簡單賦值 (Simple Assignments)¶
a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b
模式匹配賦值 (Pattern Matching Assignments)¶
a, b = tuple_or_list
a, b, *c = a_tuple
多重賦值 (Multiple Assignments)
a = b, c = tup
Print 陳述式 (Print Statements)¶
print("the result of an add:", a + b)
If 陳述式 (If Statements)¶
if a < 4:
r = -a
elif a < 3:
r = a + a
else:
r = 3 * a
除了布林值之外,浮點數、整數和張量 (Tensors) 也可以在條件式中使用,並會隱式轉換為布林值。
While 迴圈 (While Loops)¶
a = 0
while a < 4:
print(a)
a += 1
帶有 range 的 For 迴圈 (For loops with range)¶
x = 0
for i in range(10):
x *= i
遍歷 tuple 的 For 迴圈 (For loops over tuples)¶
這些會展開迴圈,為 tuple 的每個成員生成一個主體 (body)。對於每個成員,主體必須正確進行類型檢查 (type-check)。
tup = (3, torch.rand(4))
for x in tup:
print(x)
遍歷常數 nn.ModuleList 的 For 迴圈 (For loops over constant nn.ModuleList)¶
要在編譯方法中使用 nn.ModuleList
,必須將該屬性的名稱添加到類型的 __constants__
列表中,將其標記為常數。遍歷 nn.ModuleList
的 For 迴圈將在編譯時展開迴圈的主體,其中包含常數模組列表的每個成員。
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
__constants__ = ['mods']
def __init__(self):
super().__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
m = torch.jit.script(MyModule())
Break 和 Continue¶
for i in range(5):
if i == 1:
continue
if i == 3:
break
print(i)
Return¶
return a, b
變數解析 (Variable Resolution)¶
TorchScript 支援 Python 變數解析(即作用域)規則的子集。局部變數的行為與 Python 中相同,但有一個限制,即變數在函數的所有路徑上必須具有相同的類型。如果一個變數在 if 陳述式的不同分支上具有不同的類型,則在 if 陳述式結束後使用它會產生錯誤。
同樣地,如果一個變數僅在函數的某些路徑上被定義,則不允許使用該變數。
範例
@torch.jit.script
def foo(x):
if x < 0:
y = 4
print(y)
Traceback (most recent call last):
...
RuntimeError: ...
y is not defined in the false branch...
@torch.jit.script...
def foo(x):
if x < 0:
~~~~~~~~~
y = 4
~~~~~ <--- HERE
print(y)
and was used here:
if x < 0:
y = 4
print(y)
~ <--- HERE...
非局部變數在函數定義時解析為 Python 值。然後,使用 使用 Python 值 中描述的規則將這些值轉換為 TorchScript 值。
使用 Python 值 (Use of Python Values)¶
為了使編寫 TorchScript 更方便,我們允許腳本程式碼引用周圍範圍內的 Python 值。例如,任何時候引用 torch
,TorchScript 編譯器實際上是在宣告函數時將其解析為 torch
Python 模組。這些 Python 值不是 TorchScript 的第一類部分。相反,它們在編譯時被簡化為 TorchScript 支援的基本類型。這取決於編譯時引用的 Python 值的動態類型。本節描述了在 TorchScript 中存取 Python 值時使用的規則。
函數 (Functions)¶
TorchScript 可以呼叫 Python 函數。在將模型逐步轉換為 TorchScript 時,此功能非常有用。可以將模型逐函數移動到 TorchScript,並保留對 Python 函數的呼叫。這樣,您可以逐步檢查模型的正確性。
- torch.jit.is_scripting()[source][source]¶
此函數在編譯時返回 True,否則返回 False。這在使用 @unused 裝飾器時特別有用,可以在您的模型中保留尚未與 TorchScript 相容的程式碼。 .. testcode
import torch @torch.jit.unused def unsupported_linear_op(x): return x def linear(x): if torch.jit.is_scripting(): return torch.linear(x) else: return unsupported_linear_op(x)
- 返回類型
在 Python 模組上查找屬性 (Attribute Lookup On Python Modules)¶
TorchScript 可以在模組上查找屬性。像 torch.add
這樣的 內建函數 就是透過這種方式存取的。這允許 TorchScript 呼叫在其他模組中定義的函數。
Python 定義的常數 (Python-defined Constants)¶
TorchScript 還提供了一種使用 Python 中定義的常數的方法。這些可用於將超參數硬編碼到函數中,或定義通用常數。有兩種方法可以指定應將 Python 值視為常數。
作為模組的屬性查找的值被假定為常數
import math
import torch
@torch.jit.script
def fn():
return math.pi
ScriptModule 的屬性可以使用
Final[T]
進行註釋,以將其標記為常數
import torch
import torch.nn as nn
class Foo(nn.Module):
# `Final` from the `typing_extensions` module can also be used
a : torch.jit.Final[int]
def __init__(self):
super().__init__()
self.a = 1 + 4
def forward(self, input):
return self.a + input
f = torch.jit.script(Foo())
支援的常數 Python 類型包括:
int
float
bool
torch.device
torch.layout
torch.dtype
包含支援類型的 tuple
torch.nn.ModuleList
,可以在 TorchScript for 迴圈中使用
模組屬性 (Module Attributes)¶
可以使用 torch.nn.Parameter
封裝器和 register_buffer
將 tensors 指派給模組。如果類型可以被推斷,則編譯後指派給模組的其他值將被新增到編譯後的模組中。TorchScript 中可用的所有 類型 都可以用作模組屬性。Tensor 屬性在語義上與 buffers 相同。空列表和字典以及 None
值的類型無法被推斷,必須透過 PEP 526 風格 的類別註解來指定。如果類型無法被推斷並且沒有明確註解,它將不會作為屬性新增到產生的 ScriptModule
中。
範例
from typing import List, Dict
class Foo(nn.Module):
# `words` is initialized as an empty list, so its type must be specified
words: List[str]
# The type could potentially be inferred if `a_dict` (below) was not
# empty, but this annotation ensures `some_dict` will be made into the
# proper type
some_dict: Dict[str, int]
def __init__(self, a_dict):
super().__init__()
self.words = []
self.some_dict = a_dict
# `int`s can be inferred
self.my_int = 10
def forward(self, input):
# type: (str) -> int
self.words.append(input)
return self.some_dict[input] + self.my_int
f = torch.jit.script(Foo({'hi': 2}))