• 文件 >
  • 自定義編譯器 Pass 與分割器
捷徑

自定義編譯器 Pass 與分割器

Pass

Pass大致可以分為幾個軸

軸 A

  1. 建立一對 X 映射 (例如,分解)

  2. 建立多對一映射 (例如,融合)

軸 B

  1. 執行正向迭代 (例如,形狀傳播)

  2. 執行反向迭代 (例如,無效程式碼消除)

軸 C

  1. 取決於本機節點資訊 (例如,out-variant 轉換)

  2. 取決於全域圖形資訊 (例如,記憶體規劃)

我們對這些用例頻率的預測是

  1. A.1, B.1, C.1

  2. A.2

  3. B.2, C.2

等級 1

對於第一級的使用案例(建立一對多的映射、執行正向迭代,以及查看本地節點資訊),我們可以利用一個名為 ExportPass 的輔助類別。這是一種基於直譯器的方式,我們會執行每個節點並重新建立圖,但會套用指定的轉換。這讓我們可以保留 IR 規範,確保在 pass 中建立的所有節點都符合 IR 規範,包括確保堆疊追蹤、FakeTensor 值和 torch.nn.Module 階層等中繼資料會根據所做的轉換來保留和更新。

要實作此 pass,我們可以建立 ExportPass 的子類別,並實作公開的函式。當使用 graph module 呼叫時,它會執行 graph module 並建立一個新的 graph,其中包含 pass 指定的變更。這表示傳入的 graph module 必須可在 CPU 上執行,並且此不變性會在 pass 執行後維持不變。

一對一 Pass

對於一對一映射的範例,如果我們想要將 op A 替換為另一個 op B,我們可以執行給定的 fx.GraphModule,並且每次看到 op A 時,都傳回 op B。

考慮以下範例

class ReplaceInPlaceReluWithOutOfPlaceReluPass(ExportPass):
    """
    relu_ is the in-place version. Replace it with relu, which is the
    out-of-place version
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.relu_.default:
            return super().call_operator(op, args, kwargs, meta)
        return super().call_operator(Op(torch.ops.aten.relu.default), args, kwargs, meta)

# To create a pass
replace_pass = ReplaceInPlaceReluWithOutOfPlaceReluPass()
# To run a pass
new_graph_module = replace_pass(graph_module).graph_module

super().call_operator(op, args, kwargs, meta) 呼叫會建立一個 call_function FX 節點,並傳回使用給定引數執行運算符的結果。

一對多 Pass

如果我們想要進行一對多的映射,例如將 op A 替換為另外兩個 ops B 和 C,我們會進行兩次 super().call_operator 呼叫以建立兩個 FX 節點,一個使用 op B,另一個使用 op C,並傳回執行 op C 的結果。

例如

class ReplaceAddWithMulSub(ExportPass):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.add.default:
            return super().call_operator(op, args, kwargs, meta)

        x, y = args

        mul_res = super().call_operator(
            torch.ops.aten.mul.default,
            args,
            {},
            meta
        )

        return super().call_operator(
            torch.ops.aten.sub.default,
            (mul_res, y),
            {},
            meta
        )

一對零 Pass

如果我們想要移除一個 op,我們可以只傳回傳遞到函式中的值

class RemoveDetachPass(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return args[0]

利用本地資訊

利用本地節點資訊的一個例子是,如果我們想要將 graph 中的所有純量轉換為張量,我們可以執行給定的 fx.GraphModule,並且對於包含純量的每個引數,我們都會將其轉換為張量。它可能看起來像這樣

def args_map(op, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(self.op._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)

class ScalarToTensorPass(ExportPass):
    def call_operator(self, op, args, kwargs):
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(op, try_coerce, args, kwargs)
        return super().call_operator(op, args, kwargs)

第二級

為了建立多對一的映射,我們可以利用 FX 的 subgraph rewriter。給定一個 pattern,它會建立一個與模式相符的運算符子圖,然後將每個相符的子圖替換為 replacement

注意

This is an inplace operation.

patternreplacement 輸入必須是使用與您要匹配的 EXIR 圖中使用的相同 ops (ATen ops) 寫成的可呼叫函式,以便 subgraph rewriter 可以在圖中找到正確的模式。模式/替換可呼叫對象的輸入將被視為萬用字元。

考慮以下範例

from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
    traced_module, pattern, replacement
)

subgraph rewriter 傳回一個 ReplacedPatterns 的清單

@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]

注意

The nodes created by the subgraph rewriter will not have the metadata that
is normally in EXIR nodes (`stack_trace`, `val`, `nn_module_stack`).

第三級

對於建立 pass 的第三種方式,我們可以利用最基本的 PassBase。為了建立一個 pass,我們可以繼承這個類別並實作具有 pass 內容的函式 call。此外,我們可以實作函式 requiresensures,它們將在函式 call 之前和之後呼叫。請注意,這些函式也可以在 ExportPass 中被覆寫。要在 graph module 上執行 pass,我們可以將 graph module 直接傳遞給類別的實例。

考慮以下範例

class ReplaceAddPass(PassBase):

    def __init__(self, replace_op):
        self.replace_op = replace_op

    def call(self, graph_module):
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.target = self.replace_op

    # Optional to implement, will be called before call()
    def requires(self, graph_module) -> None:
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                return
        raise ValueError("No torch.add ops!")

    # Optional to implement, will be called after call()
    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        pass

# To create a pass
replace_add_with_div = ReplaceAddPass(torch.div)
# To run a pass
replace_add_with_div(graph_module)

Pass 管理器

PassManager 是一個用於在給定 graph module 上執行多個 passes 的類別。當初始化一個 PassManager 實例時,我們傳入一個我們想要執行的 passes 清單,並設定一些旗標。要在 graph module 上執行 passes 集合,我們可以將 graph module 直接傳遞給 PassManager 實例。

範例

from executorch.exir.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)

為了新增一組在每個 pass 之後執行的常見檢查,我們可以呼叫函式 set_checks(check: Callable),它接受一個可呼叫函式作為輸入。如果設定了 run_checks_after_each_pass 旗標,則在每次在 graph module 上執行 pass 後,都會呼叫 check

範例

pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)    # raises ValueError after replace_div_with_mul pass

分割器

有一些常見的基於 FX 圖的分割器,我們可以用於分割圖。但是,這些分割器不一定會產生符合 IR 規範的圖,因此在使用它們時要小心。

子圖匹配器

為了在圖中找到與特定模式匹配的子圖,我們可以利用 FX 的 SubgraphMatcher

類別屬性

  • pattern (Graph):目標匹配模式。圖中的佔位符節點將在匹配時被視為萬用字元。

  • match_output (bool):如果為 True,則模式圖中的輸出節點將被視為目標模式的一部分。如果為 False,則在匹配期間忽略輸出節點。

  • match_placeholder (bool):如果為 True,則模式圖中的佔位符節點將被視為目標模式的一部分。如果為 False,則佔位符節點將用作萬用字元。

  • remove_overlapping_matches (bool):如果為 True,在重疊匹配的情況下,只會傳回第一個匹配項。

  • ignore_literals (bool):如果為 True,則不會檢查文字是否相等,而是將它們視為萬用字元。

考慮以下範例

from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

large_model_graph = to_edge(export(LargeModel(), large_inputs)).exported_program().graph_module.graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_graph = to_edge(export(PatternModel(), pattern_inputs)).exported_program().graph_module.graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)

match 函數會回傳 InternalMatch 的列表

@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

基於能力的分割器

為了找到支援特定不變量的節點的最大子圖,我們可以利用 FX 的 CapabilityBasedPartitioner

類別屬性

  • graph_module (torch.fx.GraphModule):我們要分割的圖模組。

  • operator_support (OperatorSupportBase):用於判斷圖中節點是否在分割區中受支援的物件。

  • allows_single_node_partition (bool):如果為 True,允許形成單一節點分割區。

  • non_compute_ops (Optional[Sequence[str]]):一組被認為是“非計算”的操作 (例如 torch.ops.aten.view_operator.getitem),因此分割器不會建立只包含這些非計算操作的圖。

  • allowed_single_node_partition_ops (Optional[Sequence[str]]):一組允許在單一節點分割區中的操作。

OperatorSupportBase 類別被分割器用於判斷圖中的特定節點是否屬於分割區。這是透過覆寫 is_node_supported 函數來完成的。您可以透過使用 chain(如果任何 OperatorSupportBase 回傳 False,則回傳 False)和 any_chain(如果任何 OperatorSupportBase 回傳 True,則回傳 True)來鏈接多個 OperatorSuppportBase

考慮以下範例

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()

如果您查看基於能力的分割器,您也可能會找到一個 fuse_partition 函數,該函數將回傳一個修改後的圖,其中分割區作為子模組,並透過 call_module 節點在頂層圖中呼叫這些子模組。然而,這不符合 IR 規範,因為我們不允許 call_module 節點。

組合

我們也提供了一個組合輔助函數:generate_pattern_op_partitions

參數

  • graph_module (fx.GraphModule):我們要分割的模組

  • patterns (List[torch.fx.Graph]):torch.fx.Graph 形式的模式列表。這些圖可以透過從 exir.capture 取得的 GraphModule 中的 graph 欄位(推薦),或透過符號追蹤(可能不會產生精確的邊緣方言圖),或透過手動製作圖模組來獲得。

  • op_support (OperatorSupportBase):可以透過以下方式建立的 OperatorSupportBase

    • 直接對其進行子類別化並實作 is_node_supported()

    • 取得 create_op_support() 的結果

    • 取得 create_pattern_support() 的結果

    • 多個 OperatorSupportBase 類別與 chain()any_chain() 鏈接在一起

回傳

  • 包含節點的分區列表(最大的可能子圖),這些節點受到給定的 OperatorSupportBase 物件和給定的模式圖的聯集支援。

來源分割器

對於更複雜的使用案例,其中使用者想要基於更高等級的模組 (torch.nn.Lineartorch.nn.functional.Linear) 進行分割,這些模組現在被分解為其運算子 (aten.permute, aten.addmm),我們有以下 輔助函數

get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]

參數

  • graph:我們要分割的圖

  • wanted_sources:從這個來源分解的節點的來源列表。這可以是一個函數 (例如 torch.nn.functional.linear) 或一個葉模組類型 (例如 torch.nn.Linear)

回傳

  • 字典,將來源 (例如 torch.nn.modules.linear.Linear) 映射到 SourcePartitions 的列表,這些列表對應於從該類型的模組中扁平化的節點列表。

@dataclass
class SourcePartition():
    # Nodes in a particular partition
    nodes: List[Node]
    # Module type
    module_type: Type
    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)
    # Nodes in the partition that are being used by nodes outside of the partition
    output_nodes: List[Node] = field(default_factory=list)
    # Parameters that are being used
    params: List[str] = field(default_factory=list)

範例

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

inputs = (torch.randn(3, 3),)
edge_graph = to_edge(export(M(), inputs)).exported_program().graph_module.graph
print(edge_graph)
"""
graph():
    %arg0 : [#users=1] = placeholder[target=arg0]
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0,), kwargs = {})
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
    %_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default_1 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0_1,), kwargs = {})
    %_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
    %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
    %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
    %permute_default_2 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant2,), kwargs = {})
    %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
    %addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
    return [addmm_default_2]
"""

module_partitions = get_source_partitions(edge_graph, [torch.nn.Linear, torch.nn.ReLU])
print(module_partitions)
"""
{<class 'torch.nn.modules.linear.Linear'>: [
    ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
    ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
    ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],

 <class 'torch.nn.modules.activation.ReLU'>: [
    ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
"""

文件

Access comprehensive developer documentation for PyTorch

View Docs

教學

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources