快速鍵

Dynamo 概觀

在閱讀本節之前,請閱讀torch.compiler

TorchDynamo(或簡稱 Dynamo)是一種 Python 層級的即時 (JIT) 編譯器,旨在使未修改的 PyTorch 程式執行得更快。 Dynamo 鉤入 CPython 中的框架評估 API (PEP 523) 以在 Python 位元組碼執行之前動態修改它。 它會重寫 Python 位元組碼,將 PyTorch 運算序列提取到 FX 圖表 中,然後使用可自訂的後端進行編譯。 它透過位元組碼分析建立此 FX 圖表,並且設計為將 Python 執行與已編譯的後端混合,以獲得兩全其美的優勢 — 可用性和效能。

Dynamo 使使用不同的編譯器後端進行實驗變得容易,只需一行裝飾器 torch._dynamo.optimize() 即可使 PyTorch 程式碼執行得更快,torch.compile() 為了方便起見,將其包裝起來。

下圖示範了 PyTorch 在使用 torch.compile 和不使用時的工作方式

_images/TorchDynamo.png

TorchInductorDynamo Graph 支援的後端之一,可將圖轉換為用於 GPU 的 Triton 或用於 CPU 的 C++/OpenMP。我們有一個訓練效能儀表板,提供不同訓練後端的效能比較。您可以在 PyTorch dev-discuss 上的 TorchInductor 文章中閱讀更多內容。

如需深入了解,請閱讀以下章節、觀看深入剖析影片,並查看 dev-discuss 主題。

Dynamo 內部原理

作者: Jason AnselKaichao You

本節將介紹 Dynamo 的一些內部原理,並示範 Dynamo 的運作方式。

什麼是 guard?

Dynamo 以即時方式運作,並根據動態屬性專門化圖。以下是如何使用 Dynamo 的基本範例。可以使用 torchdynamo.optimize 裝飾函數或方法,以啟用 Dynamo 最佳化。

from typing import List
import torch
from torch import _dynamo as torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

例如,上面的第一個圖具有以下 guards

GUARDS:
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140355900538256)
check_tensor(L['a'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])
check_tensor(L['b'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])

如果任何一個 guard 失敗,則將重新捕獲並重新編譯該圖。其中有趣的 guard 是 check_tensor,它檢查以下 torch.Tensor 屬性

  • Tensor 的 Python 類別 (tensor 子類別化等)

  • dtype

  • device

  • requires_grad

  • dispatch_key (套用了執行緒本機包含/排除)

  • ndim

  • sizes*

  • strides*

完整專門化模式允許後端編譯器假設一個完全靜態的圖。不幸的是,大多數後端都需要這樣。當不處於動態形狀模式時,傳回動態形狀的運算子將觸發圖中斷。

Dynamo 在做什麼?

如果您想更好地了解 Dynamo 在做什麼,可以使用以下命令運行您的程式碼

TORCH_LOGS="+dynamo,guards,bytecode"

如果您不熟悉 Python 位元組碼,您可以新增一個反編譯器 hook,將位元組碼反編譯為人類可讀的原始碼。一個可用的工具是 depyf。如果您尚未安裝 depyf,請執行 pip install depyf。然後,新增以下程式碼以在您運行任何程式碼之前安裝反編譯 hook。

import depyf
depyf.install()

此程式碼會觸發有用 (但會產生大量垃圾訊息) 的列印輸出。

例如,toy_example 中第一個圖的列印輸出為

__compiled_fn_0 <eval_with_key>.1
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f9ca082f8a0>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}

ORIGINAL BYTECODE toy_example example.py line 12
 14           0 LOAD_FAST                0 (a)
              2 LOAD_GLOBAL              0 (torch)
              4 LOAD_METHOD              1 (abs)
              6 LOAD_FAST                0 (a)
              8 CALL_METHOD              1
             10 LOAD_CONST               1 (1)
             12 BINARY_ADD
             14 BINARY_TRUE_DIVIDE
             16 STORE_FAST               2 (x)

 15          18 LOAD_FAST                1 (b)
             20 LOAD_METHOD              2 (sum)
             22 CALL_METHOD              0
             24 LOAD_CONST               2 (0)
             26 COMPARE_OP               0 (<)
             28 POP_JUMP_IF_FALSE       19 (to 38)

 16          30 LOAD_FAST                1 (b)
             32 LOAD_CONST               3 (-1)
             34 BINARY_MULTIPLY
             36 STORE_FAST               1 (b)

 17     >>   38 LOAD_FAST                2 (x)
             40 LOAD_FAST                1 (b)
             42 BINARY_MULTIPLY
             44 RETURN_VALUE


MODIFIED BYTECODE toy_example example.py line 12
 12           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 LOAD_FAST                1 (b)
              6 CALL_FUNCTION            2
              8 UNPACK_SEQUENCE          2
             10 STORE_FAST               2 (x)
             12 POP_JUMP_IF_FALSE       12 (to 24)
             14 LOAD_GLOBAL              4 (__resume_at_30_1)
             16 LOAD_FAST                1 (b)
             18 LOAD_FAST                2 (x)
             20 CALL_FUNCTION            2
             22 RETURN_VALUE
        >>   24 LOAD_GLOBAL              5 (__resume_at_38_2)
             26 LOAD_FAST                1 (b)
             28 LOAD_FAST                2 (x)
             30 CALL_FUNCTION            2
             32 RETURN_VALUE


possible source code:
def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)

If you find the decompiled code is wrong,please submit an issue at https://github.com/youkaichao/depyf/issues.

在頂部,您可以看到 FX 圖。接下來,您會看到該函數的原始位元組碼,然後是 Dynamo 產生的修改後的位元組碼,以及用於參考的反編譯原始碼。最後,您會看到我們上面介紹的 guards。

在修改後的位元組碼中,__compiled_fn_0my_compiler() (已編譯的圖) 的傳回值。__resume_at_30_1__resume_at_38_2 都是產生的延續函數,在圖中斷後 (在位元組碼偏移量 30 和 38 處) 恢復執行。這些函數都採用以下形式

__resume_at_<offset>:
    ... restore stack state if needed ...
    JUMP_ABSOLUTE <offset> into toy_example
    ... original bytecode of toy_example ...

透過產生此 resume_at 函數,我們強制該函數的其餘部分在一個新的 Python 框架中執行,該框架會遞迴地觸發 Dynamo 以在執行首次到達該點時重新開始其捕獲。

如何檢查 Dynamo 產生的 artifacts?

若要檢查 Dynamo 產生的 artifacts,有一個 API torch._dynamo.eval_frame._debug_get_cache_entry_list,可以從函數的 __code__ 物件中擷取已編譯的程式碼和 guards。一個已編譯的函數可以有多個快取項目,每個快取項目都包含一個產生的函數來檢查 guards,以及一個 types.CodeType 物件來保存滿足 guarding 條件時要執行的程式碼。

from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn
cache_entries = _debug_get_cache_entry_list(innermost_fn(toy_example))
cache_entry = cache_entries[0]
guard, code = cache_entry.check_fn, cache_entry.code
# the guard takes the local variables of an input frame, and tells whether a re-compilation should be triggered.
import dis
dis.dis(guard)
dis.dis(code)

如果您了解 Python 位元組碼,您就可以理解上面的輸出。

對於 guard 函數,不需要檢查位元組碼。我們可以直譯存取其 guarding 條件

for code_part in guard.code_parts:
    print(code_part)

輸出為

___guarded_code.valid
___check_global_state()
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140215810860528)
___check_tensors(L['a'], L['b'], tensor_check_names=tensor_check_names)

只有當所有條件都滿足時,guard 函數才會傳回 true,並且執行已編譯的程式碼。

對於已編譯的程式碼,我們無法直譯存取其原始碼,而必須反編譯它。

from depyf import decompile
print(decompile(code))

輸出為

def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)

程式碼中引用的一些名稱為

  • 已編譯的函數,儲存在包含原始函數 toy_example 的模組的全域命名空間中。這些包括諸如 __compiled_fn_0 / __resume_at_30_1 / __resume_at_38_2 之類的名稱。

  • 用於檢查 guards 的 closure 變數。這些名稱可以從 guard.__code__.co_freevars 存取,並且這些值儲存在 guard.__closure__ 中。這些包括諸如 ___guarded_code / ___is_grad_enabled / ___are_deterministic_algorithms_enabled / ___is_torch_function_enabled / utils_device / ___check_tensors / tensor_check_names 之類的名稱。

  • guard 函式的引數 L。這是一個字典,將 toy_example 的引數名稱對應到它們的值。只有在呼叫函式時,也就是 frame evaluation API 開始作用時,這個引數才可用。簡而言之,L 是一個 dict,結構為 {'a': value_a, 'b': value_b}。因此,你可以看到程式碼使用 L['a'] 來指稱輸入變數 a

graph break 出現在編譯後的 toy_example 程式碼中,我們必須使用 Python 解譯器來選擇要執行的下一個 graph。

請注意,我們傳遞了一個簡單的 my_compiler 函式作為後端編譯器,因此 subgraph 程式碼 __resume_at_38_2__resume_at_30_1__compiled_fn_0 仍然是 Python 程式碼。 這也可以被檢查 (請忽略函式名稱,僅使用函式簽名和函式主體程式碼)

print("source code of __compiled_fn_0:")
print(innermost_fn(__compiled_fn_0).__self__.code)
print("=" * 60)
print("source code of __resume_at_30_1:")
print(decompile(__resume_at_30_1))
print("=" * 60)
print("source code of __resume_at_38_2:")
print(decompile(__resume_at_38_2))
source code of __compiled_fn_0:

def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
    l_a_ = L_a_
    l_b_ = L_b_
    abs_1 = torch.abs(l_a_)
    add = abs_1 + 1;  abs_1 = None
    truediv = l_a_ / add;  l_a_ = add = None
    sum_1 = l_b_.sum();  l_b_ = None
    lt = sum_1 < 0;  sum_1 = None
    return (truediv, lt)

# To see more debug info, please use ``graph_module.print_readable()``
============================================================
source code of __resume_at_30_1:
def <resume in toy_example>(b, x):
    b = b * -1
    return x * b

============================================================
source code of __resume_at_38_2:
def <resume in toy_example>(b, x):
    return x * b

但是,如果我們使用其他後端,例如內建的 inductor,則 subgraph 程式碼將被編譯為 GPU 的 CUDA 核心或 CPU 的 C++ 程式碼。

總而言之,編譯後的程式碼在概念上等同於以下程式碼

def compiled_example(a, b):
    L = {'a': a, 'b': b}
    for guard, code in get_cache_entries():
        if guard(L):
            return code(a, b)
    recompile_and_add_another_cache_entry()

下圖演示了 torch.compile 如何轉換和優化使用者編寫的程式碼:它首先從使用者編寫的函式中提取計算圖,然後將這些圖編譯為優化的函式,然後將它們組裝成一個新的函式,該函式在功能上與使用者編寫的程式碼等效,但經過優化以具有良好的計算速度。

_images/flowchart.jpg

若要深入了解所有這些如何在內部實作,請參閱 Dynamo 深度解析

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和高級開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題的解答

檢視資源