捷徑

後端方言

概觀

後端方言 (Backend dialect)邊緣方言 (edge dialect) 的一種特殊變體,因為它包含後端特定的節點和元數據,這些是在後端特定的圖形轉換之後產生的。後端方言是一個可選階段,只有當我們想將後端意識引入圖形時才需要。更具體地說,後端方言中的圖形可能包含僅對目標後端有意義的運算符或委託的降低模組(請參閱 委託文檔 (delegate doc))。一種使用案例是,如果我們想將運算符融合到單個運算符中,例如,將連續的 addmm + relu 融合到單個運算符 addmm_relu,我們可以在此處執行此操作。

本文檔描述瞭如何引入後端特定的運算符。

自定義運算 (custom ops) 和後端特定運算 (backend specific ops) 之間的區別:雖然自定義運算會出現在 eager 模式、ATen 方言和邊緣方言中,但後端特定運算僅由邊緣方言之後發生的 passes 引入。

何時使用

此方言允許引入不符合規範 ATen 運算符集中定義的模式,並且不顯示在上述任何方言(ATen 方言和邊緣方言)中的運算符。如果您的用例滿足以下一個或多個條件,請考慮使用後端運算符

  • 您的後端提供一個庫,該庫優化等效於子圖的特定運算符。例如,linear_relu (等效於 linear + relu) 可以在特定後端上更快地執行。

  • 需要在圖形模組降低到後端之後重新追蹤 (retrace) 它。當我們重新追蹤時,後端運算符可以轉換回原始子圖(在 ATen 方言中),而普通的自定義運算無法處理這種情況。

  • 您的後端特定運算符沒有通用的 CPU 核心,只有特定後端的核心。使用後端運算符可以通過使用原始子圖作為默認核心並保持圖形模組可運行來解決此問題。

  • 或者,如果您擔心它可能矯枉過正,並且只是想要一些更輕量級的東西,並且只需要編譯器階段的 Python 代碼,則可以使用委託 (delegate)。

API

對於運算符/子圖替換,常見的流程是

  1. 註冊一個與子圖具有相同輸入和輸出的運算符。此運算符將沒有目標特定的實現(在編譯階段也不需要),但它需要給出與子圖相同的結果。

  2. 創建一種模式,允許編譯器找到子圖並用替換項替換它。

  3. 編寫一個 pass 以將子圖替換為新運算符。

為了方便這個過程,我們提供了一個 API 來幫助減少 ExecuTorch 用戶執行這些步驟的工作量。

Pass Infra 入口點

為了將邊緣運算 (edge ops) 降低到後端運算 (backend ops),pass 將執行模式匹配以識別圖形中感興趣的邊緣運算,然後用等效的後端運算符替換它們。 有兩個 API 可以註冊這樣的 passes

  • transform()。ExportProgram 上的一個 API,允許用戶提供自定義的 passes。 請注意,這不受任何驗證器的保護,因此程序的可靠性無法保證。

  • ExecutorchBackendConfig.passes。如果在此處添加,則 pass 將成為從後端方言到 ExecutorchProgram 的降低過程的一部分。

示例:一個這樣的 pass 是 QuantFusion。 此 pass 採用“規範量化模式”,即“dequant - some_op - quant”,並將此模式融合到一個後端特定的單個運算符中,即 quantized_decomposed::some_op。 另一個更簡單的示例是 這裡,我們將 sym_size 運算符替換為 ExecuTorch 可以理解的運算符

模式綁定裝飾器 (Pattern Binding Decorator)

我們提供一個裝飾器 bind_pattern_to_op 來幫助用戶輕鬆地將其後端運算符註冊到 EXIR 中。 這個裝飾器採用

  • 一個 torch.Library 對象,它指示此後端運算符屬於哪個庫或命名空間。

  • 一個名稱或模式。 如果我們已經在 torch.Library 對象中定義了後端運算符的模式,則只需要一個名稱。 否則,如果傳入模式字符串,我們可以註冊模式。

這個裝飾器應該被添加到我們嘗試匹配(然後降低到這個後端運算)的邊緣方言的模式上。 這樣,我們將這個模式註冊為這個後端運算符的 CompositeImplicitAutograd 核心。

然後可以從 passes 訪問/使用該運算符。CompositeImplicitAutograd 核心確保

  1. 用戶無需編寫 (CPU) 可運行的核心。

  2. 確保 ExportProgram 的可追蹤性。 重新追蹤後,後端運算符將被分解為模式中使用的 ATen 運算。

示例

讓我們假設一個包含 add 和 relu 運算符的簡單程序

def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = x + y
    return torch.ops.aten.relu.default(z)

降低到邊緣方言後,它變為

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
    %aten_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.relu.default](args = (%aten_add_tensor,), kwargs = {})
    return (aten_relu_default,)

現在我想編寫一個 pass 來將 addrelu 合併到 add_relu 中,第一步是編寫一個模式

# In the pattern, we can use edge ops and ATen ops interchangably
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.ops.aten.add.Tensor(x, y)
    out = torch.ops.aten.relu.default(z)
    return out

然後我們需要從融合的運算符命名空間創建一個運算符庫,然後在我們的模式上使用裝飾器

lib = Library("foo_namespace", "DEF")

@bind_pattern_to_op(lib, "add_relu(Tensor self, Tensor other) -> Tensor")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        z = torch.ops.aten.add.Tensor(x, y)
        out = torch.ops.aten.relu.default(z)
        return out

這樣,我們將該模式註冊為 add_relu 的核心,並且它已準備好在 pass 中使用。 一個簡單的 pass 如下所示

class AddReluFusionPass(ExportPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        # decorator registers this pattern as a CompositeExplicitAutograd kernel, since there's no kernel registered before.
        @bind_pattern_to_op(lib, "add_relu")
        def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            z = torch.ops.aten.add.Tensor(x, y)
            out = torch.ops.aten.relu.default(z)
            return out

        def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.ops.foo_namespace.add_relu.default(x, y)

        subgraph_rewriter.replace_pattern(
            graph_module,
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
        )
        return PassResult(graph_module, True)

結果圖形如下所示

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %foo_namespace_add_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.foo_namespace.add_relu.default](args = (%arg0_1, %arg1_1), kwargs = {})
    return (foo_namespace_add_relu_default,)

Op Set

以下是目前使用 bind_pattern_to_op API 的後端運算符。

  • executorch_prims::add.int(SymInt a, SymInt b) -> SymInt

    • 模式:builtin.add

    • 後端:executor

  • executorch_prims::mul.int(SymInt a, SymInt b) -> SymInt

    • pattern: builtin.mul

    • 後端:executor

  • executorch_prims::sub.int(SymInt a, SymInt b) -> SymInt

    • pattern: builtin.sub

    • 後端:executor

  • executorch_prims::floordiv.int(SymInt a, SymInt b) -> SymInt

    • pattern: builtin.floordiv

    • 後端:executor

  • executorch_prims::truediv.int(Scalar a, Scalar b) -> Scalar

    • pattern: builtin.div

    • 後端:executor

  • executorch_prims::sym_float.Scalar(Scalar a) -> Scalar

    • pattern: builtin.float

    • 後端:executor

  • executorch_prims::gt.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.gt

    • 後端:executor

  • executorch_prims::lt.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.lt

    • 後端:executor

  • executorch_prims::ge.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.ge

    • 後端:executor

  • executorch_prims::le.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.le

    • 後端:executor

  • executorch_prims::eq.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.eq

    • 後端:executor

  • executorch_prims::mod.Scalar(SymInt a, SymInt b) -> SymInt

    • pattern: builtin.divmod

    • 後端:executor

  • executorch_prims::neg.Scalar(Scalar a) -> Scalar

    • pattern: operator.ne

    • 後端:executor

  • quantized_decomposed::embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor

    • pattern: source

    • backend: quantization

  • quantized_decomposed::add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc

    • pattern: source

    • backend: quantization

  • quantized_decomposed::add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor

    • pattern: source

    • backend: quantization

  • quantized_decomposed::add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc

    • pattern: source

    • backend: quantization

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源