快捷方式

torch.export IR 規格

Export IR 是一種用於編譯器的中間表示法 (IR),與 MLIR 和 TorchScript 相似。它專門設計用於表達 PyTorch 程式的語意。 Export IR 主要以簡化的運算列表表示計算,對控制流程等動態性的支援有限。

要建立 Export IR 圖,可以使用前端,透過追蹤專門化的機制來穩健地捕捉 PyTorch 程式。然後,後端可以最佳化和執行產生的 Export IR。現在可以透過 torch.export.export() 來完成。

本文檔中將涵蓋的關鍵概念包括:

  • ExportedProgram:包含 Export IR 程式的資料結構

  • Graph:由節點列表組成。

  • Nodes:代表操作、控制流程以及儲存在此節點上的元資料。

  • Values 由節點產生和消耗。

  • Types 與值和節點相關聯。

  • 還定義了值的大小和記憶體佈局。

假設

此文件假設讀者已充分熟悉 PyTorch,特別是 torch.fx 及其相關工具。因此,本文將不再描述 torch.fx 文件和論文中已有的內容。

什麼是 Export IR

Export IR 是 PyTorch 程式基於圖 (graph) 的中介表示 (intermediate representation, IR)。 Export IR 是建立在 torch.fx.Graph 之上的。換句話說,所有 Export IR 圖也是有效的 FX 圖,並且如果使用標準 FX 語義進行解釋,Export IR 可以被正確地解釋。 其中一個含義是,匯出的圖可以通過標準 FX 程式碼生成轉換為有效的 Python 程式。

本文檔將主要重點介紹 Export IR 在嚴格性方面與 FX 不同的地方,同時跳過與 FX 相似的部分。

ExportedProgram

最上層的 Export IR 結構是 torch.export.ExportedProgram 類別。 它將 PyTorch 模型的計算圖(通常是 torch.nn.Module)與該模型使用的參數或權重捆綁在一起。

torch.export.ExportedProgram 類別的一些值得注意的屬性是

  • graph_module (torch.fx.GraphModule):包含 PyTorch 模型扁平化計算圖的資料結構。 可以通過 ExportedProgram.graph 直接訪問該圖。

  • graph_signature (torch.export.ExportGraphSignature):圖的簽章,指定在圖中使用和修改的參數和緩衝區名稱。 參數和緩衝區不會儲存為圖的屬性,而是提升為圖的輸入。 graph_signature 用於追蹤這些參數和緩衝區的額外資訊。

  • state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含參數和緩衝區的資料結構。

  • range_constraints (Dict[sympy.Symbol, RangeConstraint]):對於匯出具有資料相關行為的程式,每個節點上的元資料 (metadata) 將包含符號形狀(看起來像 s0, i0)。 此屬性將符號形狀映射到其下限/上限範圍。

圖 (Graph)

Export IR 圖是以 DAG(有向無環圖)形式表示的 PyTorch 程式。 此圖中的每個節點代表特定的計算或操作,並且此圖的邊緣由節點之間的引用組成。

我們可以將圖視為具有以下架構

class Graph:
  nodes: List[Node]

實際上,Export IR 的圖是作為 torch.fx.Graph Python 類別實現的。

Export IR 圖包含以下節點(節點將在下一節中更詳細地描述)

  • 0 個或多個 op 類型為 placeholder 的節點

  • 0 個或多個 op 類型為 call_function 的節點

  • 恰好 1 個 op 類型為 output 的節點

推論: 最小的有效圖將由一個節點組成。 也就是說,節點永遠不會為空。

定義: 圖的 placeholder 節點集合表示 GraphModule 的圖的輸入。 圖的 output 節點表示 GraphModule 的圖的輸出

範例

import torch
from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
  %x : [num_users=1] = placeholder[target=x]
  %y : [num_users=1] = placeholder[target=y]
  %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
  return (add,)

以上是圖的文字表示形式,每行代表一個節點。

節點 (Node)

節點表示特定的計算或操作,並在 Python 中使用 torch.fx.Node 類別表示。 節點之間的邊緣表示為通過 Node 類別的 args 屬性對其他節點的直接引用。 使用相同的 FX 機制,我們可以表示計算圖通常需要的以下操作,例如運算符呼叫、佔位符(也稱為輸入)、條件式和迴圈。

節點具有以下架構

class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文字格式

如上面的範例所示,請注意每一行都具有以下格式

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以緊湊的格式捕獲 Node 類別中存在的所有內容,但 meta 除外。

具體來說

  • <name> 是節點的名稱,如同它出現在 node.name 中一樣。

  • <op_name>node.op 欄位,它必須是以下之一:<call_function><placeholder><get_attr><output>

  • <target> 是節點的目標,如同 node.target 一樣。 此欄位的含義取決於 op_name

  • args1, … args 4…node.args 元組中列出的內容。 如果列表中的值是 torch.fx.Node,則將特別以開頭的 % 指示。

例如,對 add 運算符的呼叫將顯示為

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x, %y 是另外兩個名稱分別為 x 和 y 的節點。 值得注意的是,字串 torch.op.aten.add.Tensor 表示實際儲存在目標欄位中的可呼叫物件,而不僅僅是它的字串名稱。

此文字格式的最後一行是

return [add]

這是一個 op_name = output 的節點,表示我們正在回傳這一個元素。

call_function

一個 call_function 節點代表對運算子的呼叫。

定義

  • 函數式(Functional): 我們稱一個可呼叫物件為“函數式”,如果它滿足以下所有要求:

    • 非變異性(Non-mutating):運算子不會改變其輸入的值(對於張量,這包括元資料和資料)。

    • 無副作用(No side effects):運算子不會改變從外部可見的狀態,例如改變模組參數的值。

  • 運算子(Operator): 是一個具有預定義模式的函數式可呼叫物件。此類運算子的範例包括函數式 ATen 運算子。

在 FX 中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

與原始 FX call_function 的差異

  1. 在 FX 圖中,一個 call_function 可以指向任何可呼叫物件。在 Export IR 中,我們將其限制為僅 ATen 運算子、自訂運算子和控制流程運算子的特定子集。

  2. 在 Export IR 中,常數引數將嵌入到圖中。

  3. 在 FX 圖中,一個 get_attr 節點可以表示讀取儲存在圖模組中的任何屬性。 但是,在 Export IR 中,這僅限於讀取子模組,因為所有參數/緩衝區將作為輸入傳遞到圖模組。

元資料

Node.meta 是一個附加到每個 FX 節點的字典。但是,FX 規範並未指定可以或將會有什麼元資料。Export IR 提供了更強的合約,特別是所有 call_function 節點都保證具有且僅具有以下元資料欄位:

  • node.meta["stack_trace"] 是一個字串,包含引用原始 Python 原始碼的 Python 堆疊追蹤。堆疊追蹤的範例如下:

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"] 描述了執行運算的輸出。它可以是 <symint><FakeTensor>List[Union[FakeTensor, SymInt]]None 類型。

  • node.meta["nn_module_stack"] 描述了 torch.nn.Module 的“堆疊追蹤”,如果節點來自 torch.nn.Module 呼叫。 例如,如果一個包含從 torch.nn.Linear 模組內部的一個 torch.nn.Sequential 模組呼叫的 addmm 運算的節點,則 nn_module_stack 看起來會像這樣:

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"] 包含在分解之前,此節點被呼叫時所來自的 torch 函數或葉節點 torch.nn.Module 類別。 例如,一個包含來自 torch.nn.Linear 模組呼叫的 addmm 運算的節點,會在它們的 source_fn 中包含 torch.nn.Linear,而一個包含來自 torch.nn.functional.Linear 模組呼叫的 addmm 運算的節點,會在它們的 source_fn 中包含 torch.nn.functional.Linear

placeholder

Placeholder 代表圖的輸入。 它的語義與 FX 中的完全相同。 Placeholder 節點必須是圖的節點清單中的前 N 個節點。 N 可以為零。

在 FX 中的表示

%name = placeholder[target = name](args = ())

target 欄位是一個字串,它是輸入的名稱。

args 如果非空,則大小應為 1,代表此輸入的預設值。

元資料

Placeholder 節點也具有 meta[‘val’],就像 call_function 節點一樣。 在這種情況下,val 欄位代表圖預期接收到此輸入參數的輸入形狀/dtype。

output

一個輸出呼叫代表函數中的 return 語句; 因此,它終止了當前圖。 只有一個輸出節點,並且它將始終是圖的最後一個節點。

在 FX 中的表示

output[](args = (%something, …))

這具有與 torch.fx 中完全相同的語義。 args 代表要傳回的節點。

元資料

輸出節點具有與 call_function 節點相同的元資料。

get_attr

get_attr 節點代表從封裝的 torch.fx.GraphModule 讀取子模組。 與來自 torch.fx.symbolic_trace() 的原始 FX 圖不同,在原始 FX 圖中 get_attr 節點用於從頂層 torch.fx.GraphModule 讀取諸如參數和緩衝區之類的屬性,參數和緩衝區作為輸入傳遞到圖模組,並儲存在頂層 torch.export.ExportedProgram 中。

在 FX 中的表示

%name = get_attr[target = name](args = ())

範例

考慮以下模型:

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 讀取包含 sin 運算子的子模組 true_graph_0

參考文獻

SymInt

SymInt 是一個物件,它可以是一個文字整數或一個代表整數的符號(在 Python 中由 sympy.Symbol 類別表示)。 當 SymInt 是一個符號時,它描述了一種在編譯時圖不知道的整數類型變數,也就是說,它的值僅在運行時才知道。

FakeTensor

FakeTensor 是一個包含張量元資料的物件。 它可以被視為具有以下元資料。

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor 的 size 欄位是一個整數或 SymInts 的列表。如果存在 SymInts,表示這個 tensor 具有動態形狀。如果存在整數,則假定該 tensor 將具有該確切的靜態形狀。TensorMeta 的 rank 永遠不是動態的。dtype 欄位表示該節點輸出的資料類型。在 Edge IR 中沒有隱式類型提升。FakeTensor 中沒有 strides。

換句話說

  • 如果 node.target 中的運算符返回一個 Tensor,那麼 node.meta['val'] 是一個描述該 tensor 的 FakeTensor。

  • 如果 node.target 中的運算符返回一個 Tensor 的 n-tuple,那麼 node.meta['val'] 是一個描述每個 tensor 的 FakeTensor 的 n-tuple。

  • 如果 node.target 中的運算符返回一個在編譯時已知的 int/float/scalar,那麼 node.meta['val'] 是 None。

  • 如果 node.target 中的運算符返回一個在編譯時未知的 int/float/scalar,那麼 node.meta['val'] 的類型是 SymInt。

例如

  • aten::add 返回一個 Tensor;所以它的規格將是一個 FakeTensor,其 dtype 和 size 是該運算符返回的 tensor 的 dtype 和 size。

  • aten::sym_size 返回一個整數;所以它的 val 將是一個 SymInt,因為它的值僅在運行時可用。

  • max_pool2d_with_indexes 返回一個 (Tensor, Tensor) 的 tuple;所以規格也將是一個由兩個 FakeTensor 物件組成的 tuple,第一個 TensorMeta 描述了返回值的第一個元素等等。

Python 代碼

def add_one(x):
  return torch.ops.aten(x, 1)

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able 類型

如果一個類型是葉類型或包含其他 Pytree-able 類型的容器類型,我們將其定義為 "Pytree-able" 類型。

注意

pytree 的概念與 JAX 的 此處 記錄的概念相同

以下類型被定義為 **葉類型**

類型

定義

Tensor

torch.Tensor

Scalar

任何來自 Python 的數值類型,包括整數類型、浮點數類型和零維 tensor。

int

Python int (在 C++ 中綁定為 int64_t)

float

Python float (在 C++ 中綁定為 double)

bool

Python bool

str

Python string

ScalarType

torch.dtype

Layout

torch.layout

MemoryFormat

torch.memory_format

Device

torch.device

以下類型被定義為 **容器類型**

類型

定義

Tuple

Python tuple

List

Python list

Dict

具有 Scalar 鍵的 Python dict

NamedTuple

Python namedtuple

Dataclass

必須通過 register_dataclass 註冊

自定義類別

使用 _register_pytree_node 定義的任何自定義類別

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源