torch.fx¶
概述¶
FX 是一個工具套件,供開發人員用來轉換 nn.Module
實例。 FX 主要由三個部分組成:符號追蹤器 (symbolic tracer)、中間表示法 (intermediate representation) 和 Python 程式碼產生 (Python code generation)。 以下是這些組件實際運作的示範
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符號追蹤器 (symbolic tracer) 執行 Python 程式碼的「符號執行 (symbolic execution)」。 它會將虛假的值(稱為 Proxies)輸入到程式碼中。 這些 Proxies 上的操作會被記錄下來。 有關符號追蹤的更多資訊,請參閱 symbolic_trace()
和 Tracer
文件。
中間表示法 (intermediate representation) 是用於存放符號追蹤期間記錄的操作的容器。 它由一個 Nodes 列表組成,這些 Nodes 代表函數輸入、調用點 (callsite)(函數、方法或 torch.nn.Module
實例)和回傳值。 有關 IR 的更多資訊,請參閱 Graph
的文件。 IR 是應用轉換的格式。
Python 程式碼產生 (Python code generation) 是使 FX 成為 Python 到 Python(或 Module 到 Module)轉換工具套件的原因。 對於每個 Graph IR,我們可以建立有效的 Python 程式碼,以匹配 Graph 的語義。 此功能封裝在 GraphModule
中,它是一個 torch.nn.Module
實例,它保存了一個 Graph
以及一個從 Graph 產生的 forward
方法。
總而言之,這個組件管道(符號追蹤 -> 中間表示法 -> 轉換 -> Python 程式碼產生)構成了 FX 的 Python 到 Python 轉換管道。 此外,這些組件可以單獨使用。 例如,符號追蹤可以單獨使用,以捕獲程式碼的形式進行分析(而不是轉換)。 程式碼產生可用於以程式方式產生模型,例如從配置檔案產生。 FX 有很多用途!
一些範例轉換可以在 examples 儲存庫中找到。
編寫轉換¶
什麼是 FX 轉換? 本質上,它是一個如下所示的函數。
import torch
import torch.fx
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# NOTE: torch.fx.symbolic_trace is a wrapper around a call to
# fx.Tracer.trace and constructing a GraphModule. We'll
# split that out in our transform to allow the caller to
# customize tracing behavior.
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: Modify this Graph or create a new one
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
您的轉換將接收一個 torch.nn.Module
,從中獲取一個 Graph
,進行一些修改,然後回傳一個新的 torch.nn.Module
。 您應該將 FX 轉換回傳的 torch.nn.Module
視為與常規 torch.nn.Module
相同 – 您可以將其傳遞給另一個 FX 轉換,您可以將其傳遞給 TorchScript,或者您可以運行它。 確保 FX 轉換的輸入和輸出都是 torch.nn.Module
將允許可組合性。
注意
也可以修改現有的 GraphModule
而不是建立新的,如下所示
import torch
import torch.fx
def transform(m : nn.Module) -> nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# Modify gm.graph
# <...>
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
return gm
請注意,您必須呼叫 GraphModule.recompile()
以使 GraphModule
上產生的 forward()
方法與修改後的 Graph
同步。
鑑於您已傳入一個已追蹤到 Graph
中的 torch.nn.Module
,現在您可以採用兩種主要方法來建立新的 Graph
。
圖表快速入門¶
有關圖語義的完整處理可以在 Graph
文件中找到,但我們將在此處介紹基礎知識。 Graph
是一種資料結構,表示 GraphModule
上的一個方法。 這需要以下資訊
該方法的輸入是什麼?
該方法內部運行的操作是什麼?
該方法的回傳值(即回傳)是什麼?
所有這三個概念都用 Node
實例表示。 讓我們通過一個簡短的範例來看看我們的意思
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
在這裡,我們定義一個名為 MyModule
的模組作為示範,實例化它,以符號追蹤它,然後呼叫 Graph.print_tabular()
方法來印出一個表格,顯示此 Graph
的節點。
運算碼 (opcode)
名稱 (name)
目標 (target)
參數 (args)
關鍵字參數 (kwargs)
佔位符 (placeholder)
x
x
()
{}
get_attr
linear_weight
linear.weight
()
{}
call_function
add_1
<內建函數 add>
(x, linear_weight)
{}
call_module
linear_1
linear
(add_1,)
{}
call_method
relu_1
relu
(linear_1,)
{}
call_function
sum_1
<內建方法 sum …>
(relu_1,)
{‘dim’: -1}
call_function
topk_1
<內建方法 topk …>
(sum_1, 3)
{}
output
output
output
(topk_1,)
{}
我們可以利用這些資訊來回答上面提出的問題。
方法的輸入是什麼? 在 FX 中,方法輸入透過特殊的
placeholder
節點指定。 在此例中,我們有一個單一的placeholder
節點,其target
為x
,表示我們有一個名為 x 的單一 (非 self) 引數。方法中的運算有哪些?
get_attr
、call_function
、call_module
和call_method
節點代表方法中的運算。所有這些語意的完整處理可以在Node
文件中找到。方法的傳回值是什麼?
Graph
中的傳回值由一個特殊的output
節點指定。
鑑於我們現在了解代碼如何在 FX 中表示的基本知識,我們現在可以探索如何編輯 Graph
。
圖形操作¶
直接圖形操作¶
建構這個新 Graph
的一種方法是直接操作你的舊圖形。 為了協助實現這一點,我們可以簡單地取得從符號追蹤取得的 Graph
並修改它。 例如,假設我們希望將 torch.add()
呼叫替換為 torch.mul()
呼叫。
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX represents its Graph as an ordered list of
# nodes, so we can iterate through them.
for node in graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.add:
node.target = torch.mul
graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return fx.GraphModule(m, graph)
我們也可以進行更複雜的 Graph
重寫,例如刪除或附加節點。 為了協助進行這些轉換,FX 具有轉換圖形的實用函數,可以在 Graph
文件中找到。 下面可以找到使用這些 API 附加 torch.relu()
呼叫的範例。
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
# Insert a new `call_function` node calling `torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# We want all places that used the value of `node` to
# now use that value after the `relu` call we've added.
# We use the `replace_all_uses_with` API to do this.
node.replace_all_uses_with(new_node)
對於僅包含替換的簡單轉換,您也可以使用 子圖重寫器。
使用 replace_pattern() 進行子圖重寫¶
FX 還在直接圖形操作之上提供了另一個層級的自動化。 replace_pattern()
API 本質上是一個用於編輯 Graph
的「尋找/替換」工具。 它允許您指定一個 pattern
和 replacement
函數,它將追蹤這些函數,在 pattern
圖形中找到運算組的實例,並用 replacement
圖形的副本替換這些實例。 這可以幫助大大自動化繁瑣的圖形操作代碼,因為轉換變得更加複雜,這可能會變得難以處理。
Proxy/Retracing¶
操作 Graph
的另一種方法是重複使用符號追蹤中使用的 Proxy
機制。 例如,假設我們想要編寫一個將 PyTorch 函數分解為較小運算的轉換。 它會將每個 F.relu(x)
呼叫轉換為 (x > 0) * x
。 一種可能性是執行必要的圖形重寫,在 F.relu
之後插入比較和乘法,然後清理原始的 F.relu
。 但是,我們可以透過使用 Proxy
物件自動將運算記錄到 Graph
中來自動化這個過程。
要使用這個方法,我們會將想要插入的操作寫成一般的 PyTorch 程式碼,並以 Proxy
物件作為參數來調用這些程式碼。這些 Proxy
物件會捕捉對它們執行的操作,並將它們附加到 Graph
中。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies,
# we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the
# return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this
# node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免顯式地操作圖形之外,使用 Proxy
也能夠讓您將重寫規則指定為原生 Python 程式碼。對於需要大量重寫規則的轉換 (例如 vmap 或 grad),這通常可以提高規則的可讀性和可維護性。請注意,在調用 Proxy
時,我們也傳遞了一個指向底層變數 graph 的 tracer。這樣做的目的是,如果圖形中的操作是 n 元的 (例如,add 是一個二元運算符),則對 Proxy
的調用不會建立圖形 tracer 的多個實例,這可能會導致意外的運行時錯誤。我們建議使用這種方法來使用 Proxy
,尤其是在不能安全地假設底層運算符是一元運算符的情況下。
Interpreter 模式¶
在 FX 中,一種有用的程式碼組織模式是迴圈遍歷 Graph
中的所有 Node
,並執行它們。這可用於多種用途,包括對流經圖形的值進行運行時分析,或透過使用 Proxy
重新追蹤來轉換程式碼。例如,假設我們想要執行一個 GraphModule
,並在運行時記錄節點上的 torch.Tensor
形狀和 dtype 屬性。它可能看起來像這樣
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
正如您所看到的,FX 的完整 interpreter 並不那麼複雜,但它可能非常有用。為了簡化使用這種模式,我們提供了 Interpreter
類別,它以一種可以透過方法覆寫來覆蓋 interpreter 執行某些方面的方式,包含了上述邏輯。
除了執行操作,我們還可以透過 interpreter 提供 Proxy
值來產生一個新的 Graph。同樣地,我們提供 Transformer
類別來包含這種模式。Transformer
的行為與 Interpreter
類似,但不是調用 run
方法來從 Module 取得具體的輸出值,而是調用 Transformer.transform()
方法來返回一個新的 GraphModule
,該模組受到您安裝為覆寫方法的任何轉換規則的約束。
除錯¶
簡介¶
通常在撰寫轉換的過程中,我們的程式碼不會完全正確。在這種情況下,我們可能需要進行一些除錯。關鍵是倒推:首先,檢查調用產生的模組的結果,以證明或反駁正確性。然後,檢查和除錯產生的程式碼。然後,除錯導致產生程式碼的轉換過程。
如果您不熟悉除錯器,請參閱輔助章節 可用的除錯器。
轉換撰寫中常見的陷阱¶
不確定的
set
迭代順序。在 Python 中,set
資料類型是無序的。使用set
來包含物件集合 (例如Node
s) 可能會導致意外的不確定性。一個例子是迭代一組Node
s 以將它們插入到Graph
中。由於set
資料類型是無序的,因此輸出程式中的操作順序將是不確定的,並且可能會在程式調用之間發生變化。建議的替代方法是使用dict
資料類型,根據 Python 3.7 (和 cPython 3.6) 的說法,它是 插入有序的。dict
可以透過將要去除重複的值儲存在dict
的鍵中,來等效地用作 set。
檢查模組的正確性¶
由於大多數深度學習模組的輸出都由浮點數 torch.Tensor
實例組成,因此檢查兩個 torch.nn.Module
的結果是否相等,不像執行簡單的相等性檢查那麼直接。為了說明這一點,讓我們使用一個例子:
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在這裡,我們試圖使用 ==
相等運算符來檢查兩個深度學習模型的數值是否相等。 然而,這並沒有明確定義,不僅因為該運算符返回的是 tensor 而不是布林值,而且由於浮點數值的比較應該使用誤差範圍(或 epsilon)來考慮浮點數運算的不可交換性 (詳情請參考 這裡)。 我們可以使用 torch.allclose()
來代替,它將給我們一個近似比較,考慮到相對和絕對容差閾值。
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
這是我們工具箱中的第一個工具,用於檢查轉換後的模組與參考實現相比,是否按照我們的預期運作。
偵錯生成的程式碼¶
由於 FX 會在 GraphModule
上生成 forward()
函式,因此使用傳統的偵錯技術 (例如 print
語句或 pdb
) 並不那麼直接。 幸運的是,我們可以使用幾種技術來偵錯生成的程式碼。
使用 pdb
¶
調用 pdb
以單步執行正在運行的程式。 儘管表示 Graph
的程式碼不在任何源文件中,但我們仍然可以在調用 forward 傳遞時使用 pdb
手動單步執行它。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
列印生成的程式碼¶
如果您想多次運行相同的程式碼,那麼使用 pdb
單步執行到正確的程式碼可能會有點繁瑣。 在這種情況下,一種方法是簡單地將生成的 forward
傳遞複製貼上到您的程式碼中,然後從那裡檢查它。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 GraphModule
中的 to_folder
函式¶
GraphModule.to_folder()
是 GraphModule
中的一個方法,允許您將生成的 FX 程式碼轉儲到一個資料夾中。 儘管像在 列印生成的程式碼 中那樣,將 forward 傳遞複製到程式碼中通常就足夠了,但使用 to_folder
檢查模組和參數可能會更容易。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
在運行上面的範例之後,我們可以查看 foo/module.py
中的程式碼,並根據需要修改它 (例如,添加 print
語句或使用 pdb
) 來偵錯生成的程式碼。
偵錯轉換¶
現在我們已經確定轉換正在建立不正確的程式碼,是時候偵錯轉換本身了。 首先,我們將檢查文檔中的 符號追蹤的限制 章節。 一旦我們驗證追蹤按預期運作,目標就變成找出我們的 GraphModule
轉換期間出了什麼問題。 編寫轉換 中可能會有一個快速的答案,但如果沒有,有幾種方法可以檢查我們追蹤的模組。
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add> (x, y) {}
output output output (add,) {}
"""
使用上面的實用函式,我們可以比較在應用轉換之前和之後追蹤的模組。 有時,一個簡單的可視化比較就足以追蹤到一個錯誤。 如果仍然不清楚發生了什麼錯誤,像 pdb
這樣的偵錯器可能是一個不錯的下一步。
基於上面的範例,考慮以下程式碼:
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
使用上面的範例,假設調用 print(traced)
向我們表明我們的轉換中存在錯誤。 我們想使用偵錯器找出發生了什麼錯誤。 我們啟動一個 pdb
會話。 我們可以透過在 transform_graph(traced)
上中斷來查看轉換期間發生了什麼,然後按下 s
以“單步執行”對 transform_graph(traced)
的調用。
我們也可以透過編輯 print_tabular
方法來列印 Graph 中 Node 的不同屬性來獲得好運。(例如,我們可能想查看 Node 的 input_nodes
和 users
。)
可用的偵錯器¶
最常見的 Python 除錯器是 pdb。你可以透過在命令列輸入 python -m pdb FILENAME.py
,以 pdb
啟動你的程式進入「除錯模式」,其中 FILENAME
是你想要除錯的檔案名稱。 之後,你可以使用 pdb
的 除錯器指令,逐步執行你的程式。 通常會在啟動 pdb
時設定一個中斷點 (b LINE-NUMBER
),然後呼叫 c
來執行程式直到該點。 這樣可以避免你必須逐步執行每一行程式碼(使用 s
或 n
)才能到達你想檢查的程式碼部分。 或者,你可以在你想中斷的那行程式碼之前寫入 import pdb; pdb.set_trace()
。 如果你加入 pdb.set_trace()
,你的程式將會在執行時自動以除錯模式啟動。(換句話說,你可以直接在命令列輸入 python FILENAME.py
,而不是 python -m pdb FILENAME.py
。) 一旦你在除錯模式下執行你的檔案,你可以逐步執行程式碼,並使用特定的指令來檢查你程式的內部狀態。 網路上有很多關於 pdb
的優秀教學,包括 RealPython 的 “Python Debugging With Pdb”。
像 PyCharm 或 VSCode 這樣的 IDE 通常內建除錯器。 在你的 IDE 中,你可以選擇 a) 透過在你的 IDE 中拉出一個終端機視窗(例如 VSCode 中的 View → Terminal)來使用 pdb
,或者 b) 使用內建的除錯器(通常是 pdb
的圖形化封裝)。
符號追蹤的限制¶
FX 使用一種符號追蹤系統(又稱 符號執行)來捕獲程式在可轉換/可分析形式中的語義。 該系統是追蹤,因為它執行程式(實際上是一個 torch.nn.Module
或函式)來記錄操作。 它是符號的,因為在此執行期間流經程式的資料不是真實資料,而是符號(FX 術語中的 Proxy
)。
雖然符號追蹤適用於大多數神經網路程式碼,但它有一些限制。
動態控制流程¶
符號追蹤的主要限制是它目前不支援動態控制流程。 也就是說,迴圈或 if
語句,其中條件可能取決於程式的輸入值。
例如,讓我們檢查以下程式
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
File "dyn.py", line 6, in func_to_trace
if x.sum() > 0:
File "pytorch/torch/fx/proxy.py", line 155, in __bool__
return self.tracer.to_bool(self)
File "pytorch/torch/fx/proxy.py", line 85, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if
語句的條件依賴於 x.sum()
的值,而 x.sum()
的值又依賴於 x
的值,而 x
是一個函式輸入。 由於 x
可以改變(例如,如果你傳遞一個新的輸入張量到追蹤的函式),這就是動態控制流程。 追蹤資訊會往回追溯你的程式碼,以顯示這種情況發生的位置。
靜態控制流程¶
另一方面,支援所謂的靜態控制流程。 靜態控制流程是不會跨多個呼叫而改變的迴圈或 if
語句。 通常,在 PyTorch 程式中,這種控制流程是由於程式碼根據超參數來決定模型架構。 作為一個具體的例子
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# This if-statement is so-called static control flow.
# Its condition does not depend on any input values
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if 語句 if self.do_activation
不依賴於任何函式輸入,因此它是靜態的。 do_activation
可以被認為是一個超參數,並且具有不同參數值的 MyModule
不同實例的追蹤具有不同的程式碼。 這是一個有效的模式,符號追蹤支援它。
許多動態控制流程的實例在語義上是靜態控制流程。 透過移除對輸入值的資料依賴性,例如透過將值移動到 Module
屬性或在符號追蹤期間將具體值綁定到參數,可以使這些實例支援符號追蹤
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
在真正的動態控制流程的情況下,包含此程式碼的程式碼段可以追蹤為對方法的呼叫(請參閱 使用 Tracer 類別自訂追蹤)或函式(請參閱 wrap()
),而不是追蹤它們。
非 torch
函式¶
FX 使用 __torch_function__
作為它攔截呼叫的機制(有關更多資訊,請參閱 技術概述)。 某些函式,例如內建的 Python 函式或 math
模組中的函式,不在 __torch_function__
的涵蓋範圍內,但我們仍然希望在符號追蹤中捕獲它們。 例如
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
Normalize `x` by the size of the batch dimension
"""
return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
File "sqrt.py", line 9, in normalize
return x / sqrt(len(x))
File "pytorch/torch/fx/proxy.py", line 161, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
該錯誤告訴我們內建函式 len
不受支援。 我們可以使用 wrap()
API,使這種類型的函式被記錄在追蹤中,作為直接呼叫
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用 Tracer
類別自訂追蹤¶
Tracer
類別是 symbolic_trace
實作的基礎類別。 透過子類化 Tracer 可以自訂追蹤的行為,如下所示
class MyCustomTracer(torch.fx.Tracer):
# Inside here you can override various methods
# to customize tracing. See the `Tracer` API
# reference
pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
葉模組¶
葉模組是在符號追蹤中顯示為呼叫而不是被追蹤的模組。 預設的葉模組集合是標準 torch.nn
模組實例的集合。 例如
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
可以透過覆寫 Tracer.is_leaf_module()
來客製化葉模組的集合。
其他¶
Tensor 建構子(例如
torch.zeros
、torch.ones
、torch.rand
、torch.randn
、torch.sparse_coo_tensor
)目前無法追蹤。可以使用決定性的建構子(
zeros
、ones
),並且它們產生之值會作為常數嵌入到追蹤中。只有當這些建構子的參數引用到動態輸入大小時,才會出現問題。在這種情況下,ones_like
或zeros_like
可能是一個可行的替代方案。非決定性的建構子(
rand
、randn
)將在追蹤中嵌入單個隨機值。這可能不是預期的行為。一種解決方法是將torch.randn
包裹在torch.fx.wrap
函數中,並改為呼叫該函數。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
此行為可能會在未來的版本中修復。
類型註釋
支援 Python 3 風格的類型註釋(例如
func(x : torch.Tensor, y : int) -> torch.Tensor
),並且符號追蹤將保留這些註釋。目前不支援 Python 2 風格的註解類型註釋
# type: (torch.Tensor, int) -> torch.Tensor
。目前不支援函數內區域變數名稱上的註釋。
關於
training
標誌和子模組的注意事項當使用函數式操作(如
torch.nn.functional.dropout
)時,通常會將 training 參數作為self.training
傳遞。在 FX 追蹤期間,這很可能會被烘焙為常數值。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! Mismatched elements: 15 / 15 (100.0%) Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) """
但是,當使用標準
nn.Dropout()
子模組時,training 標誌會被封裝,並且由於保留了nn.Module
物件模型,因此可以更改。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
由於這種差異,請考慮將與
training
標誌動態互動的模組標記為葉節點模組。
API 參考¶
- torch.fx.symbolic_trace(root, concrete_args=None)[source][source]¶
符號追蹤 API
給定一個
nn.Module
或函數實例root
,此函數將返回一個GraphModule
,該GraphModule
是通過記錄在追蹤root
時看到的運算來建構的。concrete_args
允許您部分專門化您的函數,無論是為了移除控制流還是資料結構。例如
def f(a, b): if b == True: return a else: return a * 2
由於控制流的存在,FX 通常無法追蹤此程式碼。但是,我們可以使用 concrete_args 來專門化 b 的值,以追蹤此程式碼
f = fx.symbolic_trace(f, concrete_args={"b": False}) assert f(3, False) == 6
請注意,儘管您仍然可以傳入不同的 b 值,但它們將被忽略。
我們也可以使用 concrete_args 從我們的函數中消除資料結構處理。這將使用 pytrees 來扁平化您的輸入。為了避免過度專門化,請為不應專門化的值傳入 fx.PH。例如
def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}) assert f({"a": 1, "b": 2, "c": 4}) == 7
- 參數
root (Union[torch.nn.Module, Callable]) – 要追蹤並轉換為 Graph 表示的模組或函數。
concrete_args (Optional[Dict[str, any]]) – 要部分專門化的輸入
- 返回
從
root
記錄的運算建立的模組。- 返回類型
注意
保證此 API 的向後相容性。
- torch.fx.wrap(fn_or_name)[source][source]¶
可以在模組級別範圍呼叫此函數,以將 fn_or_name 註冊為「葉節點函數」。 「葉節點函數」將在 FX 追蹤中保留為 CallFunction 節點,而不是被追蹤
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap("my_custom_function") def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y)
此函數也可以等效地用作裝飾器
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
一個被包裝的函數可以被認為是一個“葉節點函數”,類似於“葉節點模組”的概念,也就是說,它們是在 FX 追蹤中保留為呼叫而不是被追蹤的函數。
- 參數
fn_or_name (Union[str, Callable]) – 要在呼叫時插入圖形中的函數或全域函數的名稱
注意
保證此 API 的向後相容性。
- class torch.fx.GraphModule(*args, **kwargs)[source][source]¶
GraphModule 是一個由 fx.Graph 產生的 nn.Module。GraphModule 有一個
graph
屬性,以及由該graph
產生的code
和forward
屬性。警告
當
graph
被重新賦值時,code
和forward
將會自動重新產生。然而,如果您在沒有重新賦值graph
屬性本身的情況下,編輯graph
的內容,您必須呼叫recompile()
以更新產生的程式碼。注意
保證此 API 的向後相容性。
- __init__(root, graph, class_name='GraphModule')[原始碼][原始碼]¶
建構一個 GraphModule。
- 參數
root (Union[torch.nn.Module, Dict[str, Any]) –
root
可以是一個 nn.Module 實例,或是一個將字串映射到任何屬性類型的 Dict。如果root
是一個 Module,則在 Graph 的 Nodes 的target
欄位中,任何對基於 Module 的物件的引用(透過完整名稱)將會從root
的 Module 階層結構中的相應位置複製到 GraphModule 的 module 階層結構中。如果root
是一個 dict,則在 Node 的target
中找到的完整名稱將會在 dict 的鍵中直接查找。Dict 映射到的物件將會被複製到 GraphModule 的 module 階層結構中的適當位置。graph (Graph) –
graph
包含此 GraphModule 應該用於程式碼產生的節點class_name (str) –
name
表示此 GraphModule 的名稱,用於除錯目的。如果未設定,所有錯誤訊息都將報告為源自GraphModule
。將此設定為root
的原始名稱或在轉換上下文中具有意義的名稱可能會很有幫助。
注意
保證此 API 的向後相容性。
- add_submodule(target, m)[原始碼][原始碼]¶
將給定的子模組添加到
self
。如果
target
的子路徑不存在,則會在此處安裝空的 Module。- 參數
- 返回
- 子模組是否可以插入。對於
此方法返回 True,由
target
表示的鏈中的每個物件必須 a) 尚不存在,或 b) 參考一個nn.Module
(不是參數或其他屬性)
- 返回類型
注意
保證此 API 的向後相容性。
- delete_all_unused_submodules()[原始碼][原始碼]¶
從
self
刪除所有未使用的子模組。如果以下任何一項為真,則 Module 會被視為「已使用」:1. 它有已使用的子節點 2. 它的 forward 透過
call_module
節點直接呼叫 3. 它有一個非 Module 屬性,該屬性從get_attr
節點使用可以呼叫此方法來清理
nn.Module
,而無需在每個未使用的子模組上手動呼叫delete_submodule
。注意
保證此 API 的向後相容性。
- delete_submodule(target)[原始碼][原始碼]¶
從
self
刪除給定的子模組。如果
target
不是有效的目標,則不會刪除該模組。- 參數
target (str) – 新子模組的完整字串名稱(請參閱
nn.Module.get_submodule
中的範例,了解如何指定完整字串。)- 返回
- 目標字串是否引用了
我們想要刪除的子模組。返回值
False
表示target
不是對子模組的有效引用。
- 返回類型
注意
保證此 API 的向後相容性。
- print_readable(print_output=True, include_stride=False, include_device=False, colored=False)[原始碼][原始碼]¶
傳回為目前的 GraphModule 及其子 GraphModules 產生的 Python 程式碼
警告
此 API 是實驗性的,不具有回溯相容性。
- class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[原始碼][原始碼]¶
Graph
是 FX 中間表示法中使用的主要資料結構。它由一系列Node
組成,每個Node
代表呼叫點(或其他語法結構)。Node
的列表共同構成一個有效的 Python 函式。例如,以下程式碼
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk( torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 ) m = MyModule() gm = torch.fx.symbolic_trace(m)
將產生以下 Graph
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1
關於
Graph
中表示的運算的語義,請參閱Node
。注意
保證此 API 的向後相容性。
- __init__(owning_module=None, tracer_cls=None, tracer_extras=None)[原始碼][原始碼]¶
建構一個空的 Graph。
注意
保證此 API 的向後相容性。
- call_function(the_function, args=None, kwargs=None, type_expr=None)[原始碼][原始碼]¶
將
call_function
Node
插入到Graph
中。call_function
節點表示對 Python 可呼叫物件的呼叫,由the_function
指定。- 參數
the_function (Callable[..., Any]) – 要呼叫的函式。可以是任何 PyTorch 運算子、Python 函式,或
builtins
或operator
命名空間的成員。args (Optional[Tuple[Argument, ...]]) – 要傳遞給被呼叫函式的位置引數。
kwargs (Optional[Dict[str, Argument]]) – 要傳遞給被呼叫函式的關鍵字引數
type_expr (Optional[Any]) – 一個可選的型別註釋,代表此節點輸出將具有的 Python 型別。
- 返回
新建立並插入的
call_function
節點。- 返回類型
注意
相同的插入點和型別表示式規則適用於此方法,如
Graph.create_node()
。注意
保證此 API 的向後相容性。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[原始碼][原始碼]¶
插入一個
call_method
Node
到Graph
中。call_method
節點代表呼叫args
中第 0 個元素的指定方法。- 參數
method_name (str) – 要應用於 self 參數的方法名稱。 例如,如果 args[0] 是一個代表
Tensor
的Node
,那麼要呼叫該Tensor
上的relu()
,請將relu
傳遞給method_name
。args (Optional[Tuple[Argument, ...]]) – 要傳遞給所呼叫方法的位置參數。 請注意,這應該包含一個
self
參數。kwargs (Optional[Dict[str, Argument]]) – 要傳遞給所呼叫方法的關鍵字參數
type_expr (Optional[Any]) – 一個可選的型別註釋,代表此節點輸出將具有的 Python 型別。
- 返回
新建立並插入的
call_method
節點。- 返回類型
注意
相同的插入點和型別表示式規則適用於此方法,如
Graph.create_node()
。注意
保證此 API 的向後相容性。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[原始碼][原始碼]¶
插入一個
call_module
Node
到Graph
中。call_module
節點代表呼叫Module
階層中Module
的 forward() 函式。- 參數
module_name (str) – 要呼叫的
Module
階層中Module
的完整名稱。 例如,如果追蹤的Module
有一個名為foo
的子模組,該子模組有一個名為bar
的子模組,則應將完整名稱foo.bar
作為module_name
傳遞以呼叫該模組。args (Optional[Tuple[Argument, ...]]) – 要傳遞給所呼叫方法的位置參數。 請注意,這不應該包含
self
參數。kwargs (Optional[Dict[str, Argument]]) – 要傳遞給所呼叫方法的關鍵字參數
type_expr (Optional[Any]) – 一個可選的型別註釋,代表此節點輸出將具有的 Python 型別。
- 返回
新建立並插入的
call_module
節點。- 返回類型
注意
相同的插入點和型別表示式規則適用於此方法,如
Graph.create_node()
。注意
保證此 API 的向後相容性。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[原始碼][原始碼]¶
建立一個
Node
並將其新增到Graph
中目前插入點。 請注意,可以使用Graph.inserting_before()
和Graph.inserting_after()
設定目前的插入點。- 參數
op (str) – 此節點的運算碼。 為 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’ 之一。 這些運算碼的語意在
Graph
文件字串中描述。args (Optional[Tuple[Argument, ...]]) – 是此節點的引數元組。
kwargs (Optional[Dict[str, Argument]]) – 此節點的 kwargs
name (Optional[str]) –
Node
的可選字串名稱。 這將影響 Python 產生的程式碼中指派給的值的名稱。type_expr (Optional[Any]) – 一個可選的型別註釋,代表此節點輸出將具有的 Python 型別。
- 返回
新建立並插入的節點。
- 返回類型
注意
保證此 API 的向後相容性。
- eliminate_dead_code(is_impure_node=None)[原始碼][原始碼]¶
根據每個節點的使用者數量以及節點是否具有任何副作用,從圖表中移除所有無效程式碼。在呼叫之前,圖表必須經過拓撲排序。
- 參數
- 返回
圖表是否因這次傳遞而發生變更。
- 返回類型
範例
在無效程式碼被消除之前,下方 a = x + 1 中的 a 沒有使用者,因此可以從圖表中消除,而不會產生任何影響。
def forward(self, x): a = x + 1 return x + self.attr_1
在無效程式碼被消除後,a = x + 1 已被移除,而 forward 的其餘部分則保持不變。
def forward(self, x): return x + self.attr_1
警告
無效程式碼消除具有一些啟發法,可以避免移除具有副作用的節點(請參閱 Node.is_impure),但一般而言,覆蓋率非常差,因此您應該假設此方法不可靠,除非您知道您的 FX 圖表完全由功能性操作組成,或者您提供自己的自訂函式來偵測具有副作用的節點。
注意
保證此 API 的向後相容性。
- erase_node(to_erase)[原始碼][原始碼]¶
從
Graph
中移除Node
。如果Graph
中仍然存在該節點的使用者,則會拋出例外狀況。- 參數
to_erase (Node) – 要從
Graph
中移除的Node
。
注意
保證此 API 的向後相容性。
- find_nodes(*, op, target=None, sort=True)[原始碼][原始碼]¶
允許快速查詢節點
- 參數
- 返回
具有請求的操作和目標的可迭代節點。
警告
此 API 是實驗性的,不具有回溯相容性。
- get_attr(qualified_name, type_expr=None)[原始碼][原始碼]¶
將
get_attr
節點插入圖表中。get_attr
Node
代表從Module
階層中提取屬性。- 參數
qualified_name (str) – 要檢索的屬性的完整名稱。例如,如果追蹤的 Module 具有一個名為
foo
的子模組,該子模組具有一個名為bar
的子模組,該子模組具有一個名為baz
的屬性,則應該將完整名稱foo.bar.baz
作為qualified_name
傳遞。type_expr (Optional[Any]) – 一個可選的型別註釋,代表此節點輸出將具有的 Python 型別。
- 返回
新建立並插入的
get_attr
節點。- 返回類型
注意
此方法套用與
Graph.create_node
相同的插入點和類型表達式規則。注意
保證此 API 的向後相容性。
- graph_copy(g, val_map, return_output_node=False)[原始碼][原始碼]¶
將給定圖表中的所有節點複製到
self
中。- 參數
- 返回
如果
g
擁有一個output
節點,則self
中目前的值會等同於g
中的輸出值。否則為None
。- 返回類型
Optional[Union[Tuple[Optional[Union[Tuple[Argument, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]], …], Sequence[Optional[Union[Tuple[Argument, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[str, Optional[Union[Tuple[Argument, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
保證此 API 的向後相容性。
- inserting_after(n=None)[原始碼][原始碼]¶
- 設定 create_node 及相關方法將插入到圖 (graph) 中的位置。
當在 'with' 語句中使用時,這將暫時設定插入點,並在 with 語句結束時還原它。
with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently
參數
- n (Optional[Node]): 要在其後插入的節點。如果為 None,則將在
整個圖的開頭之後插入。
- 返回
一個資源管理器,它將在
__exit__
上還原插入點。
注意
保證此 API 的向後相容性。
- inserting_before(n=None)[原始碼][原始碼]¶
- 設定 create_node 及相關方法將插入到圖 (graph) 中的位置。
當在 'with' 語句中使用時,這將暫時設定插入點,並在 with 語句結束時還原它。
with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently
參數
- n (Optional[Node]): 要在其前插入的節點。如果為 None,則將在
整個圖的開頭之後插入。
- 返回
一個資源管理器,它將在
__exit__
上還原插入點。
注意
保證此 API 的向後相容性。
- lint()[原始碼][原始碼]¶
對此圖執行各種檢查,以確保其格式正確。特別是: - 檢查節點是否具有正確的所有權 (由該圖擁有) - 檢查節點是否以拓撲順序出現 - 如果此圖具有擁有的 GraphModule,則檢查目標是否存在於該 GraphModule 中
注意
保證此 API 的向後相容性。
- node_copy(node, arg_transform=<function Graph.<lambda>>)[原始碼][原始碼]¶
將節點從一個圖複製到另一個圖。
arg_transform
需要將參數從節點的圖轉換為 self 的圖。範例# Copying all the nodes in `g` into `new_graph` g: torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
- 參數
- 返回類型
注意
保證此 API 的向後相容性。
- property nodes: _node_list¶
取得構成此圖的節點列表。
請注意,此
Node
列表表示形式是雙向鏈結串列。迭代期間的變更 (例如,刪除節點、新增節點) 是安全的。- 返回
節點的雙向鏈結串列。請注意,可以在此列表上調用
reversed
以切換迭代順序。
- on_generate_code(make_transformer)[原始碼][原始碼]¶
註冊在產生 Python 代碼時的轉換器函數
- 參數
- make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc])
一個傳回要註冊的代碼轉換器的函數。此函數由 on_generate_code 呼叫以取得代碼轉換器。
此函數也作為其輸入提供當前註冊的代碼轉換器 (如果沒有註冊任何內容,則為 None),以防不希望覆寫它。這對於將代碼轉換器鏈接在一起很有用。
- 返回
一個上下文管理器,當在 with 語句中使用時,會自動還原先前註冊的代碼轉換器。
範例
gm: fx.GraphModule = ... # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual # debugging with the PDB library. def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): gm.graph.on_generate_code(lambda _: insert_pdb) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb(current_trans(body) if current_trans else body) ) ) gm.recompile() gm(*inputs) # drops into pdb
此函數也可以用作上下文管理器,其好處是可以自動還原先前註冊的代碼轉換器
# ... continue from previous example with gm.graph.on_generate_code(lambda _: insert_pdb): # do more stuff with `gm`... gm.recompile() gm(*inputs) # drops into pdb # now previous code transformer is restored (but `gm`'s code with pdb # remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告
此 API 是實驗性的,不具有回溯相容性。
- output(result, type_expr=None)[原始碼][原始碼]¶
將
output
Node
插入到Graph
中。output
節點表示 Python 代碼中的return
語句。result
是應傳回的值。- 參數
result (Argument) – 要傳回的值。
type_expr (Optional[Any]) – 一個可選的型別註釋,代表此節點輸出將具有的 Python 型別。
注意
此方法套用與
Graph.create_node
相同的插入點和類型表達式規則。注意
保證此 API 的向後相容性。
- placeholder(name, type_expr=None, default_value)[原始碼][原始碼]¶
將
placeholder
節點插入到圖中。placeholder
表示函數輸入。- 參數
name (str) – 輸入值的名稱。這對應於此
Graph
代表的函式的位置引數名稱。type_expr (Optional[Any]) – 一個可選的類型註釋,表示此節點的輸出將具有的 Python 類型。在某些情況下,這對於正確的程式碼生成是必需的(例如,當函式後續在 TorchScript 編譯中使用時)。
default_value (Any) – 此函式引數應採用的預設值。注意:為了允許 None 作為預設值,應將 inspect.Signature.empty 作為此引數傳遞,以指定參數_不_具有預設值。
- 返回類型
注意
此方法套用與
Graph.create_node
相同的插入點和類型表達式規則。注意
保證此 API 的向後相容性。
- python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)[原始碼][原始碼]¶
將此
Graph
轉換為有效的 Python 程式碼。- 參數
root_module (str) – 在其上查找限定名稱目標的根模組的名稱。 這通常是 ‘self’。
- 返回
src:表示物件的 Python 原始碼 globals:src 中全域名稱的字典 -> 它們引用的物件。
- 返回類型
一個 PythonCode 物件,由兩個欄位組成
注意
保證此 API 的向後相容性。
- class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[原始碼][原始碼]¶
Node
是一種資料結構,表示Graph
內的個別操作。 在大多數情況下,節點表示對各種實體(例如運算符、方法和模組)的調用點(某些例外情況包括指定函式輸入和輸出的節點)。 每個Node
都有一個由其op
屬性指定的函式。 每個op
值的Node
語義如下placeholder
表示一個函式輸入。name
屬性指定此值將採用的名稱。target
類似地是引數的名稱。args
包含:1) 沒有任何內容,或 2) 表示函式輸入的預設參數的單個引數。kwargs
是無關緊要的。 佔位符對應於圖形列印輸出中的函式參數(例如x
)。get_attr
從模組層次結構中檢索參數。name
類似地是賦予提取結果的名稱。target
是參數在模組層次結構中的位置的完整限定名稱。args
和kwargs
是無關緊要的call_function
將一個自由函式應用於某些值。name
類似地是要賦予值的名稱。target
是要應用的函式。args
和kwargs
代表函式的參數,遵循 Python 的呼叫慣例。call_module
將模組層次結構中的模組的forward()
方法應用於給定的參數。name
與先前相同。target
是要呼叫的模組層次結構中模組的完整名稱。args
和kwargs
代表要調用模組的參數,不包括 self 參數。call_method
呼叫值上的方法。name
類似。target
是要應用於self
參數的方法的字串名稱。args
和kwargs
代表要調用模組的參數,包括 self 參數。output
包含追蹤函式的輸出,位於其args[0]
屬性中。這對應於 Graph 列印輸出中的「return」語句。
注意
保證此 API 的向後相容性。
- property all_input_nodes: List[Node]¶
傳回作為此 Node 輸入的所有 Node。這等同於迭代
args
和kwargs
並僅收集作為 Node 的值。- 返回
出現在此
Node
的args
和kwargs
中的Nodes
列表,按照該順序排列。
- append(x)[source][source]¶
將
x
插入到圖形中節點列表中的此節點之後。等同於self.next.prepend(x)
- 參數
x (Node) – 要放在此節點之後的節點。必須是同一個圖形的成員。
注意
保證此 API 的向後相容性。
- property args: Tuple[Optional[Union[Tuple[Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]], ...], Sequence[Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[str, Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]], ...]¶
此
Node
的參數元組。參數的解釋取決於節點的運算碼。有關更多資訊,請參閱Node
文件字串。允許賦值給此屬性。所有使用和使用者的帳戶都會在賦值時自動更新。
- format_node(placeholder_names=None, maybe_return_typename=None)[source][source]¶
傳回
self
的描述性字串表示形式。此方法可以在沒有參數的情況下用作偵錯工具。
此函式也在
Graph
的__str__
方法中於內部使用。總而言之,placeholder_names
和maybe_return_typename
中的字串構成了此 Graph 周圍的 GraphModule 中自動生成的forward
函式的簽名。placeholder_names
和maybe_return_typename
不應以其他方式使用。- 參數
- 返回
- 如果 1) 我們將
format_node
用作內部輔助函式 在
Graph
的__str__
方法中,且 2)self
是一個佔位符 Node,則傳回None
。否則,傳回目前 Node 的描述性字串表示形式。
- 如果 1) 我們將
- 返回類型
注意
保證此 API 的向後相容性。
- insert_arg(idx, arg)[source][source]¶
使用給定的索引將一個位置參數插入到參數列表中。
- 參數
idx (int) – 要在其前面插入元素的
self.args
中的元素索引。arg (引數 Argument) – 要插入到
args
中的新引數值。
注意
保證此 API 的向後相容性。
- is_impure()[原始碼][原始碼]¶
返回此操作是否為 impure,也就是說,如果其操作是一個 placeholder 或 output,或者如果是一個 impure 的 call_function 或 call_module。
- 返回
操作是否為 impure。
- 返回類型
警告
此 API 是實驗性的,不具有回溯相容性。
- property kwargs: Dict[str, Optional[Union[Tuple[Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]], ...], Sequence[Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[str, Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]]¶
此
Node
的關鍵字引數字典。引數的解釋取決於節點的 opcode。請參閱Node
docstring 以取得更多資訊。允許賦值給此屬性。所有使用和使用者的帳戶都會在賦值時自動更新。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[原始碼][原始碼]¶
返回 Python targets 的正規化引數。這意味著 args/kwargs 將與 module/functional 的簽名匹配,並且如果 normalize_to_only_use_kwargs 為 true,則僅以位置順序返回 kwargs。 也會填入預設值。 不支援僅限位置的參數或變數參數。
支援模組呼叫。
可能需要 arg_types 和 kwarg_types 才能消除過載的歧義。
- 參數
root (torch.nn.Module) – 用於解析模組目標的模組。
arg_types (Optional[Tuple[Any]]) – args 的 arg 類型 Tuple
kwarg_types (Optional[Dict[str, Any]]) – kwargs 的 arg 類型 Dict
normalize_to_only_use_kwargs (bool) – 是否正規化為僅使用 kwargs。
- 返回
返回 NamedTuple ArgsKwargsPair,如果未成功,則返回 None。
- 返回類型
Optional[ArgsKwargsPair]
警告
此 API 是實驗性的,不具有回溯相容性。
- prepend(x)[原始碼][原始碼]¶
將 x 插入到圖中節點列表中的此節點之前。範例
Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax
- 參數
x (Node) – 要放在此節點之前的節點。 必須是同一圖的成員。
注意
保證此 API 的向後相容性。
- replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[原始碼][原始碼]¶
用 Node
replace_with
替換圖中所有對self
的使用。- 參數
- 返回
對其進行此更改的 Nodes 列表。
- 返回類型
注意
保證此 API 的向後相容性。
- replace_input_with(old_input, new_input)[原始碼][原始碼]¶
迴圈遍歷
self
的輸入節點,並用new_input
替換所有old_input
的實例。注意
保證此 API 的向後相容性。
- property stack_trace: Optional[str]¶
傳回追蹤期間記錄的 Python 堆疊追蹤(如果有的話)。如果使用 fx.Tracer 進行追蹤,則此屬性通常由 Tracer.create_proxy 填入。若要在追蹤期間記錄堆疊追蹤以進行偵錯,請在 Tracer 實例上設定 record_stack_traces = True。如果使用 dynamo 進行追蹤,則此屬性預設會由 OutputGraph.create_proxy 填入。
stack_trace 會將最內層的 frame 放在字串的結尾。
- class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[原始碼][原始碼]¶
Tracer
類別實作torch.fx.symbolic_trace
的符號追蹤功能。呼叫symbolic_trace(m)
相當於Tracer().trace(m)
。可以將 Tracer 子類化,以覆寫追蹤程序的各種行為。可以覆寫的不同行為在本類別的方法的 docstring 中說明。
注意
保證此 API 的向後相容性。
- call_module(m, forward, args, kwargs)[原始碼][原始碼]¶
此方法指定
Tracer
在遇到呼叫nn.Module
實例時的行為。預設行為是透過
is_leaf_module
檢查呼叫的模組是否為葉模組。如果是,則發出call_module
節點,指向Graph
中的m
。否則,正常呼叫Module
,並追蹤其forward
函式中的運算。可以覆寫此方法,例如建立巢狀追蹤的 GraphModule,或在跨
Module
邊界追蹤時的任何其他所需行為。- 參數
m (Module) – 要發出呼叫的模組
forward (Callable) – 要叫用的
Module
的 forward() 方法args (Tuple) – 模組呼叫位置的 args
kwargs (Dict) – 模組呼叫位置的 kwargs
- 返回
Module 呼叫的回傳值。如果發出了
call_module
節點,則這是Proxy
值。否則,它是Module
叫用所傳回的任何值。- 返回類型
注意
保證此 API 的向後相容性。
- create_arg(a)[原始碼][原始碼]¶
一種方法,用於指定在準備要用作
Graph
中節點的引數的值時的追蹤行為。預設行為包括
反覆運算集合類型(例如 tuple、list、dict),並以遞迴方式呼叫元素上的
create_args
。給定 Proxy 物件,傳回對基礎 IR
Node
的參考給定非 Proxy Tensor 物件,發出各種案例的 IR
對於 Parameter,發出指向該 Parameter 的
get_attr
節點對於非 Parameter Tensor,將 Tensor 儲存在指向該屬性的特殊屬性中。
可以覆寫此方法以支援更多類型。
- 參數
a (Any) – 要作為
Graph
中的Argument
發出的值。- 返回
將值
a
轉換成適當的Argument
- 返回類型
Argument (引數)
注意
保證此 API 的向後相容性。
- create_args_for_root(root_fn, is_module, concrete_args=None)[source][source]¶
建立對應於
root
Module 簽名的placeholder
節點。此方法會檢視 root 的簽名並相應地發出這些節點,同時支援*args
和**kwargs
。警告
此 API 是實驗性的,不具有回溯相容性。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]¶
根據目標 (target)、引數 (args)、關鍵字引數 (kwargs) 和名稱 (name) 插入圖形節點。
可以覆寫此方法來對節點建立中使用的值進行額外的檢查、驗證或修改。例如,可能希望禁止記錄原地 (in-place) 操作。
注意
保證此 API 的向後相容性。
- 返回類型
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]¶
從給定的引數建立一個 Node,然後回傳包裝在 Proxy 物件中的 Node。
如果 kind = ‘placeholder’,則我們正在建立一個代表函式參數的 Node。如果我們需要編碼預設參數,我們使用
args
tuple。args
對於placeholder
Node 來說通常是空的。注意
保證此 API 的向後相容性。
- get_fresh_qualname(prefix)[source][source]¶
取得 prefix 的新名稱並回傳。此函式確保它不會與圖形上的現有屬性衝突。
注意
保證此 API 的向後相容性。
- 返回類型
- getattr(attr, attr_val, parameter_proxy_cache)[source][source]¶
此方法指定當我們在呼叫
nn.Module
實例時呼叫 getattr 時,此Tracer
的行為。預設行為是回傳屬性的 proxy 值。它還將 proxy 值儲存在
parameter_proxy_cache
中,以便未來的呼叫將重複使用該 proxy,而不是建立新的 proxy。可以覆寫此方法,例如,在查詢參數時不回傳 proxy。
- 參數
attr (str) – 要查詢的屬性的名稱
attr_val (Any) – 屬性的值
parameter_proxy_cache (Dict[str, Any]) – 從屬性名稱到 proxy 的快取
- 返回
從 getattr 呼叫回傳的值。
警告
此 API 是實驗性的,不具有回溯相容性。
- is_leaf_module(m, module_qualified_name)[source][source]¶
一種指定給定的
nn.Module
是否為「葉節點 (leaf)」module 的方法。葉節點 modules 是 IR 中出現的最小單位,由
call_module
呼叫引用。預設情況下,PyTorch 標準函式庫命名空間 (torch.nn) 中的 Modules 是葉節點 modules。除非透過此參數另行指定,否則所有其他 modules 都會被追蹤,並且記錄其組成的操作。- 參數
m (Module) – 正在查詢的 module
module_qualified_name (str) – 此 module 的根路徑。例如,如果您有一個 module 階層,其中子 module
foo
包含子 modulebar
,而bar
包含子 modulebaz
,則該 module 將以限定名稱foo.bar.baz
出現在此處。
- 返回類型
注意
保證此 API 的向後相容性。
- iter(obj)[source]¶
- 當 proxy 物件被迭代時呼叫,例如
當用於控制流程時。通常我們不知道該怎麼做,因為我們不知道 proxy 的值,但自訂的追蹤器可以使用 create_node 將更多資訊附加到圖形節點,並且可以選擇傳回一個迭代器。
注意
保證此 API 的向後相容性。
- 返回類型
- keys(obj)[原始碼]¶
- 當 proxy 物件呼叫 keys() 方法時呼叫。
當在 proxy 上呼叫 ** 時會發生這種情況。如果 ** 要在您的自訂追蹤器中運作,則應傳回一個迭代器。
注意
保證此 API 的向後相容性。
- 返回類型
- path_of_module(mod)[原始碼][原始碼]¶
輔助方法,用於尋找
mod
在root
的 Module 階層中的合格名稱。 例如,如果root
有一個名為foo
的子模組,而foo
有一個名為bar
的子模組,則將bar
傳遞到此函數將傳回字串 “foo.bar”。注意
保證此 API 的向後相容性。
- to_bool(obj)[原始碼]¶
- 當 proxy 物件被轉換為布林值時呼叫,例如
當用於控制流程時。通常我們不知道該怎麼做,因為我們不知道 proxy 的值,但自訂的追蹤器可以使用 create_node 將更多資訊附加到圖形節點,並且可以選擇傳回一個值。
注意
保證此 API 的向後相容性。
- 返回類型
- class torch.fx.Proxy(node, tracer=None)[原始碼][原始碼]¶
Proxy
物件是Node
封裝器,它們在符號追蹤期間流經程式,並將它們觸及的所有操作(torch
函數呼叫、方法呼叫、運算子)記錄到不斷增長的 FX Graph 中。如果您正在進行圖形轉換,您可以將您自己的
Proxy
方法包裝在原始Node
周圍,以便您可以使用重載的運算子將其他內容新增到Graph
。Proxy
物件無法迭代。換句話說,如果Proxy
用於迴圈中或作為*args
/**kwargs
函數引數,則符號追蹤器會擲回錯誤。有兩種主要解決方法:1. 將無法追蹤的邏輯分解為頂層函數,並在其上使用
fx.wrap
。2. 如果控制流程是靜態的(即,迴圈行程計數基於某些超參數),則可以將程式碼保留在其原始位置,並重構為類似以下內容:for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
有關 Proxy 內部結構的更詳細說明,請查看 torch/fx/README.md 中的“Proxy”部分
注意
保證此 API 的向後相容性。
- class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[原始碼][原始碼]¶
Interpreter 逐個節點執行 FX 圖。 這種模式對於許多事情都很有用,包括編寫程式碼轉換以及分析傳遞。
可以覆寫 Interpreter 類別中的方法以自訂執行行為。 呼叫階層中可覆寫方法的對應
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
範例
假設我們想要交換所有
torch.neg
的實例與torch.sigmoid
,反之亦然 (包含它們的Tensor
方法等效項)。我們可以像這樣繼承 Interpreter:class NegSigmSwapInterpreter(Interpreter): def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid())
- 參數
module (torch.nn.Module) – 要執行的模組。
garbage_collect_values (bool) – 是否在 Module 執行過程中,於數值最後一次使用後將其刪除。這可確保執行期間的最佳記憶體使用量。可以停用此功能,例如,透過查看
Interpreter.env
屬性來檢查執行中的所有中間數值。graph (Optional[Graph]) – 如果傳入,interpreter 將執行此圖,而不是 module.graph,並使用提供的 module 引數來滿足任何狀態請求。
注意
保證此 API 的向後相容性。
- boxed_run(args_list)[source][source]¶
透過解釋來執行 module 並傳回結果。這使用 “boxed” 呼叫慣例,您傳遞一個引數列表,該列表將由 interpreter 清除。這確保了輸入張量會及時釋放。
注意
保證此 API 的向後相容性。
- call_function(target, args, kwargs)[source][source]¶
執行
call_function
節點並傳回結果。- 參數
target (Target) – 此節點的呼叫目標。 有關語義的詳細資訊,請參閱Node
args (Tuple) – 此調用的位置引數 Tuple
kwargs (Dict) – 此調用的關鍵字引數 Dict
- 返回類型
- 傳回
Any: 函式調用傳回的值
注意
保證此 API 的向後相容性。
- call_method(target, args, kwargs)[source][source]¶
執行
call_method
節點並傳回結果。- 參數
target (Target) – 此節點的呼叫目標。 有關語義的詳細資訊,請參閱Node
args (Tuple) – 此調用的位置引數 Tuple
kwargs (Dict) – 此調用的關鍵字引數 Dict
- 返回類型
- 傳回
Any: 方法調用傳回的值
注意
保證此 API 的向後相容性。
- call_module(target, args, kwargs)[source][source]¶
執行
call_module
節點並傳回結果。- 參數
target (Target) – 此節點的呼叫目標。 有關語義的詳細資訊,請參閱Node
args (Tuple) – 此調用的位置引數 Tuple
kwargs (Dict) – 此調用的關鍵字引數 Dict
- 返回類型
- 傳回
Any: 模組調用傳回的值
注意
保證此 API 的向後相容性。
- fetch_args_kwargs_from_env(n)[source][source]¶
從目前的執行環境中,提取節點
n
的args
和kwargs
的具體值。- 參數
n (Node) – 應提取
args
和kwargs
的節點。- 返回
args
和kwargs
具有n
的具體值。- 返回類型
Tuple[Tuple, Dict]
注意
保證此 API 的向後相容性。
- fetch_attr(target)[source][source]¶
從
self.module
的Module
階層中提取屬性。- 參數
target (str) – 要提取的屬性的完整名稱
- 返回
屬性的值。
- 返回類型
Any
注意
保證此 API 的向後相容性。
- get_attr(target, args, kwargs)[source][source]¶
執行
get_attr
節點。將從self.module
的Module
階層中檢索屬性值。- 參數
target (Target) – 此節點的呼叫目標。 有關語義的詳細資訊,請參閱Node
args (Tuple) – 此調用的位置引數 Tuple
kwargs (Dict) – 此調用的關鍵字引數 Dict
- 返回
檢索的屬性的值
- 返回類型
Any
注意
保證此 API 的向後相容性。
- map_nodes_to_values(args, n)[source][source]¶
遞迴地遍歷
args
,並在目前的執行環境中查找每個Node
的具體值。- 參數
args (Argument) – 在其中查找具體值的資料結構
n (Node) –
args
所屬的節點。這僅用於錯誤報告。
- 返回類型
Optional[Union[Tuple[Optional[Union[Tuple[Argument, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]], …], Sequence[Optional[Union[Tuple[Argument, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[str, Optional[Union[Tuple[Argument, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
保證此 API 的向後相容性。
- output(target, args, kwargs)[source][source]¶
執行一個
output
節點。這實際上只是檢索output
節點所引用的值並將其返回。- 參數
target (Target) – 此節點的呼叫目標。 有關語義的詳細資訊,請參閱Node
args (Tuple) – 此調用的位置引數 Tuple
kwargs (Dict) – 此調用的關鍵字引數 Dict
- 返回
輸出節點所引用的返回值
- 返回類型
Any
注意
保證此 API 的向後相容性。
- placeholder(target, args, kwargs)[source][source]¶
執行一個
placeholder
節點。 請注意,這是具狀態的:Interpreter
維護一個內部迭代器,用於迭代傳遞給run
的參數,並且此方法返回該迭代器的 next()。- 參數
target (Target) – 此節點的呼叫目標。 有關語義的詳細資訊,請參閱Node
args (Tuple) – 此調用的位置引數 Tuple
kwargs (Dict) – 此調用的關鍵字引數 Dict
- 返回
檢索到的參數值。
- 返回類型
Any
注意
保證此 API 的向後相容性。
- class torch.fx.Transformer(module)[source][source]¶
Transformer
是一種特殊類型的直譯器,可產生新的Module
。 它公開了一個transform()
方法,該方法傳回轉換後的Module
。Transformer
不需要引數來執行,如同Interpreter
。Transformer
完全以符號方式工作。範例
假設我們要將所有
torch.neg
的實例與torch.sigmoid
進行交換,反之亦然(包括它們的Tensor
方法等效項)。 我們可以像這樣對Transformer
進行子類別化class NegSigmSwapXformer(Transformer): def call_function( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- 參數
module (GraphModule) – 要轉換的
Module
。
注意
保證此 API 的向後相容性。
- get_attr(target, args, kwargs)[原始碼][原始碼]¶
執行一個
get_attr
節點。在Transformer
中,這會被覆寫以插入一個新的get_attr
節點到輸出圖形中。- 參數
target (Target) – 此節點的呼叫目標。 有關語義的詳細資訊,請參閱Node
args (Tuple) – 此調用的位置引數 Tuple
kwargs (Dict) – 此調用的關鍵字引數 Dict
- 返回類型
注意
保證此 API 的向後相容性。
- placeholder(target, args, kwargs)[原始碼][原始碼]¶
執行一個
placeholder
節點。在Transformer
中,這會被覆寫以插入一個新的placeholder
到輸出圖形中。- 參數
target (Target) – 此節點的呼叫目標。 有關語義的詳細資訊,請參閱Node
args (Tuple) – 此調用的位置引數 Tuple
kwargs (Dict) – 此調用的關鍵字引數 Dict
- 返回類型
注意
保證此 API 的向後相容性。
- torch.fx.replace_pattern(gm, pattern, replacement)[原始碼][原始碼]¶
匹配 GraphModule (
gm
) 的 Graph 中所有可能的非重疊運算子及其資料依賴項集合(pattern
),然後將每個匹配的子圖替換為另一個子圖(replacement
)。- 參數
gm (GraphModule) – 包裹要操作的 Graph 的 GraphModule
pattern (Union[Callable, GraphModule]) – 要在
gm
中匹配以進行替換的子圖replacement (Union[Callable, GraphModule]) – 用於替換
pattern
的子圖
- 返回
一個
Match
物件的列表,表示pattern
在原始圖形中匹配到的位置。如果沒有匹配項,則列表為空。Match
定義為class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
- 返回類型
List[Match]
範例
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]).sum() def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上面的程式碼會先在
traced_module
的forward
方法中匹配pattern
。模式匹配是基於 use-def 關係,而不是節點名稱。例如,如果您在pattern
中有p = torch.cat([a, b])
,您可以匹配原始forward
函式中的m = torch.cat([a, b])
,儘管變數名稱不同(p
vsm
)。pattern
中的return
語句僅基於其值進行匹配;它可能匹配或不匹配到較大圖形中的return
語句。換句話說,該模式不一定要延伸到較大圖形的末尾。當模式匹配時,它將從較大的函式中移除,並替換為
replacement
。如果較大的函式中有多個pattern
的匹配項,則每個非重疊的匹配項都將被替換。如果發生匹配重疊,則將替換重疊匹配項集合中找到的第一個匹配項。(“第一”在這裡定義為節點的 use-def 關係拓撲排序中的第一個。在大多數情況下,第一個節點是直接出現在self
之後的參數,而最後一個節點是函式返回的任何內容。)一個重要的注意事項是
pattern
Callable 的參數必須在 Callable 本身中使用,並且replacement
Callable 的參數必須與模式匹配。第一個規則解釋了為什麼在上面的程式碼區塊中,forward
函式具有參數x, w1, w2
,但pattern
函式僅具有參數w1, w2
。pattern
沒有使用x
,因此不應將x
指定為參數。作為第二個規則的範例,考慮替換def pattern(x, y): return torch.neg(x) + torch.relu(y)
為
def replacement(x, y): return torch.relu(x)
在這種情況下,
replacement
需要與pattern
相同數量的參數(x
和y
),即使參數y
沒有在replacement
中使用。在呼叫
subgraph_rewriter.replace_pattern
後,產生的 Python 程式碼如下所示def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
保證此 API 的向後相容性。