自訂後端¶
概述¶
torch.compile
提供了一種直接的方法,讓使用者定義自訂後端。
後端函數的合約為 (gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable
。
後端函式可以被 TorchDynamo 呼叫,TorchDynamo 是 torch.compile
的圖追蹤元件,在追蹤 FX 圖之後,預期會回傳一個編譯過的函式,該函式等同於被追蹤的 FX 圖。回傳的可呼叫物件應與傳遞到後端的原始 torch.fx.GraphModule
的 forward
函式具有相同的合約:(*args: torch.Tensor) -> List[torch.Tensor]
。
為了讓 TorchDynamo 呼叫您的後端,請將您的後端函式作為 torch.compile
中的 backend
關鍵字參數傳入。例如:
import torch
def my_custom_backend(gm, example_inputs):
return gm.forward
def f(...):
...
f_opt = torch.compile(f, backend=my_custom_backend)
@torch.compile(backend=my_custom_backend)
def g(...):
...
更多範例請見下方。
註冊自定義後端¶
您可以使用 register_backend
修飾器來註冊您的後端,例如:
from torch._dynamo import register_backend
@register_backend
def my_compiler(gm, example_inputs):
...
除了 register_backend
修飾器之外,如果您的後端位於另一個 Python 套件中,您也可以通過 Python 套件的 entry points 來註冊您的後端,這提供了一種讓套件為另一個套件註冊外掛程式的方式。
提示
您可以在 Python 封裝文件中了解更多關於 entry_points
的資訊。
要通過 entry_points
註冊您的後端,您可以將您的後端函式添加到套件的 setup.py
檔案中的 torch_dynamo_backends
entry point 群組,例如:
...
setup(
...
'torch_dynamo_backends': [
'my_compiler = your_module.submodule:my_compiler',
]
...
)
請將 =
之前的 my_compiler
替換為您的後端的名稱,並將 =
之後的部分替換為您的後端函式的模組和函式名稱。在安裝套件後,entry point 將被添加到您的 Python 環境中。當您呼叫 torch.compile(model, backend="my_compiler")
時,PyTorch 將首先搜尋已使用 register_backend
註冊的名稱為 my_compiler
的後端。如果找不到,它將繼續在通過 entry_points
註冊的所有後端中搜尋。
註冊有兩個目的
您可以將包含您的後端函式名稱的字串傳遞給
torch.compile
,而不是函式本身,例如:torch.compile(model, backend="my_compiler")
。這是與 minifier 搭配使用所必需的。來自 minifier 的任何產生的程式碼都必須呼叫註冊您的後端函式的程式碼,通常是通過
import
語句。
AOTAutograd 之後的自定義後端¶
可以定義由 AOTAutograd 而非 TorchDynamo 呼叫的自定義後端。 這對於以下兩個主要原因很有用
使用者可以定義支援模型訓練的後端,因為 AOTAutograd 可以產生用於編譯的向後圖。
AOTAutograd 產生由 core Aten ops 組成的 FX 圖。 因此,自定義後端只需要支援 core Aten opset,這是一個比整個 torch/Aten opset 小得多的 opset。
使用 torch._dynamo.backends.common.aot_autograd
包裝您的後端,並像以前一樣使用帶有 backend
關鍵字參數的 torch.compile
。 由 aot_autograd
包裝的後端函式應與以前具有相同的合約。
後端函式通過 fw_compiler
(正向編譯器)或 bw_compiler
(反向編譯器)關鍵字參數傳遞到 aot_autograd
。 如果未指定 bw_compiler
,則反向編譯函式預設為正向編譯函式。
一個需要注意的是,AOTAutograd 要求後端回傳的編譯函式必須是 "boxed"。 這可以通過使用 functorch.compile.make_boxed_func
包裝編譯函式來完成。
例如:
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func
def my_compiler(gm, example_inputs):
return make_boxed_func(gm.forward)
my_backend = aot_autograd(fw_compiler=my_compiler) # bw_compiler=my_compiler
model_opt = torch.compile(model, backend=my_backend)
範例¶
偵錯後端¶
如果您想更好地了解編譯期間發生的情況,您可以建立一個自定義編譯器(在本節中稱為後端),它將 pretty print 從 Dynamo 的位元組碼分析中提取的 fx GraphModule
並回傳一個 forward()
可呼叫物件。
例如:
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@torch.compile(backend=my_compiler)
def fn(x, y):
a = torch.cos(x)
b = torch.sin(y)
return a + b
fn(torch.randn(10), torch.randn(10))
執行上述範例會產生以下輸出
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ---------- --------
placeholder x x () {}
placeholder y y () {}
call_function cos <built-in method cos of type object at 0x7f1a894649a8> (x,) {}
call_function sin <built-in method sin of type object at 0x7f1a894649a8> (y,) {}
call_function add <built-in function add> (cos, sin) {}
output output output ((add,),) {}
這對於 torch.nn.Module
也有效,如下所示
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(torch.cos(x))
mod = MockModule()
optimized_mod = torch.compile(mod, backend=my_compiler)
optimized_mod(torch.randn(10))
讓我們看看另一個具有控制流程的範例
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@torch.compile(backend=my_compiler)
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
執行此範例會產生以下輸出
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f8d259298a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (b, -1) {}
call_function mul_1 <built-in function mul> (x, mul) {}
output output output ((mul_1,),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- --------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (x, b) {}
output output output ((mul,),) {}
The order of the last two graphs is nondeterministic depending
on which one is encountered first by the just-in-time compiler.
快速後端¶
整合提供卓越效能的自定義後端也很容易,我們將整合一個使用 optimize_for_inference 的真實後端
def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
scripted = torch.jit.script(gm)
return torch.jit.optimize_for_inference(scripted)
然後您應該能夠使用以下程式碼優化任何現有程式碼
@torch.compile(backend=optimize_for_inference_compiler)
def code_to_accelerate():
...
可組合後端¶
TorchDynamo 包含許多後端,可以使用 torch._dynamo.list_backends()
列出。 您可以使用以下程式碼將這些後端組合在一起
from torch._dynamo import lookup_backend
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
try:
trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
if trt_compiled is not None:
return trt_compiled
except Exception:
pass
# first backend failed, try something else...
try:
inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
if inductor_compiled is not None:
return inductor_compiled
except Exception:
pass
return gm.forward