自定義編譯器 Pass 與分割器¶
Pass¶
Pass大致可以分為幾個軸
軸 A
建立一對 X 映射 (例如,分解)
建立多對一映射 (例如,融合)
軸 B
執行正向迭代 (例如,形狀傳播)
執行反向迭代 (例如,無效程式碼消除)
軸 C
取決於本機節點資訊 (例如,out-variant 轉換)
取決於全域圖形資訊 (例如,記憶體規劃)
我們對這些用例頻率的預測是
A.1, B.1, C.1
A.2
B.2, C.2
等級 1¶
對於第一級的使用案例(建立一對多的映射、執行正向迭代,以及查看本地節點資訊),我們可以利用一個名為 ExportPass
的輔助類別。這是一種基於直譯器的方式,我們會執行每個節點並重新建立圖,但會套用指定的轉換。這讓我們可以保留 IR 規範,確保在 pass 中建立的所有節點都符合 IR 規範,包括確保堆疊追蹤、FakeTensor 值和 torch.nn.Module 階層等中繼資料會根據所做的轉換來保留和更新。
要實作此 pass,我們可以建立 ExportPass
的子類別,並實作公開的函式。當使用 graph module 呼叫時,它會執行 graph module 並建立一個新的 graph,其中包含 pass 指定的變更。這表示傳入的 graph module 必須可在 CPU 上執行,並且此不變性會在 pass 執行後維持不變。
一對一 Pass¶
對於一對一映射的範例,如果我們想要將 op A 替換為另一個 op B,我們可以執行給定的 fx.GraphModule
,並且每次看到 op A 時,都傳回 op B。
考慮以下範例
class ReplaceInPlaceReluWithOutOfPlaceReluPass(ExportPass):
"""
relu_ is the in-place version. Replace it with relu, which is the
out-of-place version
"""
def call_operator(self, op, args, kwargs, meta):
if op != torch.ops.aten.relu_.default:
return super().call_operator(op, args, kwargs, meta)
return super().call_operator(Op(torch.ops.aten.relu.default), args, kwargs, meta)
# To create a pass
replace_pass = ReplaceInPlaceReluWithOutOfPlaceReluPass()
# To run a pass
new_graph_module = replace_pass(graph_module).graph_module
super().call_operator(op, args, kwargs, meta)
呼叫會建立一個 call_function
FX 節點,並傳回使用給定引數執行運算符的結果。
一對多 Pass¶
如果我們想要進行一對多的映射,例如將 op A 替換為另外兩個 ops B 和 C,我們會進行兩次 super().call_operator
呼叫以建立兩個 FX 節點,一個使用 op B,另一個使用 op C,並傳回執行 op C 的結果。
例如
class ReplaceAddWithMulSub(ExportPass):
"""
Original:
def f(x, y):
return x + y
After pass:
def f(x, y):
z = x * y
return z - y
"""
def call_operator(self, op, args, kwargs, meta):
if op != torch.ops.aten.add.default:
return super().call_operator(op, args, kwargs, meta)
x, y = args
mul_res = super().call_operator(
torch.ops.aten.mul.default,
args,
{},
meta
)
return super().call_operator(
torch.ops.aten.sub.default,
(mul_res, y),
{},
meta
)
一對零 Pass¶
如果我們想要移除一個 op,我們可以只傳回傳遞到函式中的值
class RemoveDetachPass(ExportPass):
def call_operator(self, op, args, kwargs, meta):
if op not in (
torch.ops.aten.detach.default,
torch.ops.aten.detach_copy.default,
):
return super().call_operator(op, args, kwargs, meta)
assert len(args) == 1
return args[0]
利用本地資訊¶
利用本地節點資訊的一個例子是,如果我們想要將 graph 中的所有純量轉換為張量,我們可以執行給定的 fx.GraphModule
,並且對於包含純量的每個引數,我們都會將其轉換為張量。它可能看起來像這樣
def args_map(op, fn, args, kwargs):
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
args = list(args)
kwargs = kwargs.copy()
# Update the argument based on the function passed
def update(key, args, schema):
args[key] = fn(args[key], schema)
# Update each argument in the schema
for i, schema in enumerate(self.op._schema.arguments):
if schema.name in kwargs:
update(schema.name, kwargs, schema)
elif not schema.kwarg_only and i < len(args):
update(i, args, schema)
class ScalarToTensorPass(ExportPass):
def call_operator(self, op, args, kwargs):
def try_coerce(value, arg):
return (
torch.tensor(value)
if isinstance(value, (float, int, bool))
and type(arg.type) == torch.TensorType
else value
)
args, kwargs = args_map(op, try_coerce, args, kwargs)
return super().call_operator(op, args, kwargs)
第二級¶
為了建立多對一的映射,我們可以利用 FX 的 subgraph rewriter。給定一個 pattern
,它會建立一個與模式相符的運算符子圖,然後將每個相符的子圖替換為 replacement
。
注意
This is an inplace operation.
pattern
和 replacement
輸入必須是使用與您要匹配的 EXIR 圖中使用的相同 ops (ATen ops) 寫成的可呼叫函式,以便 subgraph rewriter 可以在圖中找到正確的模式。模式/替換可呼叫對象的輸入將被視為萬用字元。
考慮以下範例
from torch.fx import subgraph_rewriter
def replace_patterns(graph_module):
def pattern(x, y):
x = torch.ops.aten.add.Tensor(x, y)
x = torch.ops.aten.mul.Tensor(x, y)
return x
def replacement(x, y):
return torch.ops.aten.sub.Tensor(x, y)
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
traced_module, pattern, replacement
)
subgraph rewriter 傳回一個 ReplacedPatterns
的清單
@dataclass
class ReplacedPatterns:
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
# List of nodes that were added into the graph
replacements: List[Node]
注意
The nodes created by the subgraph rewriter will not have the metadata that
is normally in EXIR nodes (`stack_trace`, `val`, `nn_module_stack`).
第三級¶
對於建立 pass 的第三種方式,我們可以利用最基本的 PassBase
。為了建立一個 pass,我們可以繼承這個類別並實作具有 pass 內容的函式 call
。此外,我們可以實作函式 requires
和 ensures
,它們將在函式 call
之前和之後呼叫。請注意,這些函式也可以在 ExportPass
中被覆寫。要在 graph module 上執行 pass,我們可以將 graph module 直接傳遞給類別的實例。
考慮以下範例
class ReplaceAddPass(PassBase):
def __init__(self, replace_op):
self.replace_op = replace_op
def call(self, graph_module):
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.add:
node.target = self.replace_op
# Optional to implement, will be called before call()
def requires(self, graph_module) -> None:
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == torch.add:
return
raise ValueError("No torch.add ops!")
# Optional to implement, will be called after call()
def ensures(self, graph_module: torch.fx.GraphModule) -> None:
pass
# To create a pass
replace_add_with_div = ReplaceAddPass(torch.div)
# To run a pass
replace_add_with_div(graph_module)
Pass 管理器¶
PassManager
是一個用於在給定 graph module 上執行多個 passes 的類別。當初始化一個 PassManager
實例時,我們傳入一個我們想要執行的 passes 清單,並設定一些旗標。要在 graph module 上執行 passes 集合,我們可以將 graph module 直接傳遞給 PassManager
實例。
範例
from executorch.exir.pass_manager import PassManager
pm = PassManager(
passes=[replace_add_with_div, replace_div_with_mul],
run_checks_after_each_pass=True,
suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
為了新增一組在每個 pass 之後執行的常見檢查,我們可以呼叫函式 set_checks(check: Callable)
,它接受一個可呼叫函式作為輸入。如果設定了 run_checks_after_each_pass
旗標,則在每次在 graph module 上執行 pass 後,都會呼叫 check
。
範例
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])
def check_div_target(graph_module):
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target != torch.div:
raise ValueError("Target should be div!")
pm.add_checks(check_div_target)
pm(graph_module) # raises ValueError after replace_div_with_mul pass
分割器¶
有一些常見的基於 FX 圖的分割器,我們可以用於分割圖。但是,這些分割器不一定會產生符合 IR 規範的圖,因此在使用它們時要小心。
子圖匹配器¶
為了在圖中找到與特定模式匹配的子圖,我們可以利用 FX 的 SubgraphMatcher
。
類別屬性
pattern (Graph)
:目標匹配模式。圖中的佔位符節點將在匹配時被視為萬用字元。match_output (bool)
:如果為 True,則模式圖中的輸出節點將被視為目標模式的一部分。如果為 False,則在匹配期間忽略輸出節點。match_placeholder (bool)
:如果為 True,則模式圖中的佔位符節點將被視為目標模式的一部分。如果為 False,則佔位符節點將用作萬用字元。remove_overlapping_matches (bool)
:如果為 True,在重疊匹配的情況下,只會傳回第一個匹配項。ignore_literals (bool)
:如果為 True,則不會檢查文字是否相等,而是將它們視為萬用字元。
考慮以下範例
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
class LargeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight = torch.nn.Parameter(torch.ones(3, 3))
self._bias = torch.nn.Parameter(torch.ones(3, 3))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias, x, self._weight)
large_model_graph = to_edge(export(LargeModel(), large_inputs)).exported_program().graph_module.graph
class PatternModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
pattern_graph = to_edge(export(PatternModel(), pattern_inputs)).exported_program().graph_module.graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
match
函數會回傳 InternalMatch
的列表
@dataclass
class InternalMatch():
# Nodes from which the match was found
anchors: List[Node]
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] = field(default_factory=dict)
# Nodes in target graph that are matched placeholder in pattern
placeholder_nodes: List[Node] = field(default_factory=list)
# Nodes in matched subgraph returned by output
returning_nodes: List[Node] = field(default_factory=list)
基於能力的分割器¶
為了找到支援特定不變量的節點的最大子圖,我們可以利用 FX 的 CapabilityBasedPartitioner
。
類別屬性
graph_module (torch.fx.GraphModule)
:我們要分割的圖模組。operator_support (OperatorSupportBase)
:用於判斷圖中節點是否在分割區中受支援的物件。allows_single_node_partition (bool)
:如果為 True,允許形成單一節點分割區。non_compute_ops (Optional[Sequence[str]])
:一組被認為是“非計算”的操作 (例如torch.ops.aten.view
和_operator.getitem
),因此分割器不會建立只包含這些非計算操作的圖。allowed_single_node_partition_ops (Optional[Sequence[str]])
:一組允許在單一節點分割區中的操作。
OperatorSupportBase
類別被分割器用於判斷圖中的特定節點是否屬於分割區。這是透過覆寫 is_node_supported
函數來完成的。您可以透過使用 chain
(如果任何 OperatorSupportBase 回傳 False,則回傳 False)和 any_chain
(如果任何 OperatorSupportBase 回傳 True,則回傳 True)來鏈接多個 OperatorSuppportBase
。
考慮以下範例
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
class AddMulOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
]
capability_partitioner = CapabilityBasedPartitioner(
graph_module,
op_support,
)
# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
如果您查看基於能力的分割器,您也可能會找到一個 fuse_partition
函數,該函數將回傳一個修改後的圖,其中分割區作為子模組,並透過 call_module
節點在頂層圖中呼叫這些子模組。然而,這不符合 IR 規範,因為我們不允許 call_module
節點。
組合¶
我們也提供了一個組合輔助函數:generate_pattern_op_partitions
參數
graph_module (fx.GraphModule)
:我們要分割的模組patterns (List[torch.fx.Graph])
:torch.fx.Graph 形式的模式列表。這些圖可以透過從 exir.capture 取得的 GraphModule 中的graph
欄位(推薦),或透過符號追蹤(可能不會產生精確的邊緣方言圖),或透過手動製作圖模組來獲得。op_support (OperatorSupportBase)
:可以透過以下方式建立的 OperatorSupportBase直接對其進行子類別化並實作
is_node_supported()
取得
create_op_support()
的結果取得
create_pattern_support()
的結果多個 OperatorSupportBase 類別與
chain()
或any_chain()
鏈接在一起
回傳
包含節點的分區列表(最大的可能子圖),這些節點受到給定的 OperatorSupportBase 物件和給定的模式圖的聯集支援。
來源分割器¶
對於更複雜的使用案例,其中使用者想要基於更高等級的模組 (torch.nn.Linear
或 torch.nn.functional.Linear
) 進行分割,這些模組現在被分解為其運算子 (aten.permute
, aten.addmm
),我們有以下 輔助函數
get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]
參數
graph
:我們要分割的圖wanted_sources
:從這個來源分解的節點的來源列表。這可以是一個函數 (例如torch.nn.functional.linear
) 或一個葉模組類型 (例如torch.nn.Linear
)
回傳
字典,將來源 (例如
torch.nn.modules.linear.Linear
) 映射到SourcePartitions
的列表,這些列表對應於從該類型的模組中扁平化的節點列表。
@dataclass
class SourcePartition():
# Nodes in a particular partition
nodes: List[Node]
# Module type
module_type: Type
# Nodes in the graph that are needed as inputs to the partition
input_nodes: List[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the partition
output_nodes: List[Node] = field(default_factory=list)
# Parameters that are being used
params: List[str] = field(default_factory=list)
範例
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(3, 3)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(3, 5)
def forward(self, x):
x = self.linear1(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
inputs = (torch.randn(3, 3),)
edge_graph = to_edge(export(M(), inputs)).exported_program().graph_module.graph
print(edge_graph)
"""
graph():
%arg0 : [#users=1] = placeholder[target=arg0]
%_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
%permute_default : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0,), kwargs = {})
%_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
%_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
%permute_default_1 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0_1,), kwargs = {})
%_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
%relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
%_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
%permute_default_2 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant2,), kwargs = {})
%_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
%addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
return [addmm_default_2]
"""
module_partitions = get_source_partitions(edge_graph, [torch.nn.Linear, torch.nn.ReLU])
print(module_partitions)
"""
{<class 'torch.nn.modules.linear.Linear'>: [
ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],
<class 'torch.nn.modules.activation.ReLU'>: [
ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
"""