捷徑

Dynamo 深入探討

TorchDynamo(或簡稱 Dynamo)是 torch.compile 中的追蹤器,而且它通常是那些令人難以置信的回溯的罪魁禍首。但是,我們不能盲目地將這些錯誤歸咎於 Dynamo。為了向使用者提供其所具有的彈性,Dynamo 被賦予了理解任何 Python 程式的艱鉅任務。特別是,Dynamo 必須在內部實作 Python 程式語言的很大一部分!

在這篇文章中,我們將從頭開始探討 Dynamo 的內部設計。我們將討論它提供的功能以及如何實現。閱讀完本文後,您將更深入地了解在使用 torch.compiled 編譯 PyTorch 程式時發生錯誤的原因,或者成功編譯但速度提升不如預期的原因。

Dynamo 簡介

在深入探討所有實現細節之前,讓我們先討論 Dynamo 的作用。

Dynamo 是一個追蹤器 (tracer)。這意味著,給定一個函數及其輸入,它會執行該函數並將指令的線性序列(沒有控制流程)記錄到一個圖中。例如,考慮以下程式:

import torch

@torch.compile
def mse(x, y):
    z = (x - y) ** 2
    return z.sum()

x = torch.randn(200)
y = torch.randn(200)
mse(x, y)

如果我們將此程式儲存到 example.py 檔案中,然後執行

TORCH_LOGS=graph_code python example.py

我們會看到 Dynamo 追蹤的輸出:

def forward(l_x_: torch.Tensor, l_y_: torch.Tensor):
    # File: example.py:5, code: z = (x - y) ** 2
    sub = l_x_ - l_y_
    z = sub ** 2
    # File: example.py:6, code: return z.sum()
    sum_1 = z.sum()
    return (sum_1,)

我們將此稱為函數針對給定輸入的圖(或追蹤)。它通過 FX 圖表示。我們將簡單地將 FX 圖視為一個儲存函數調用列表的容器。

我們首先應該注意的是,該圖是一個 PyTorch 操作的線性序列。1 Dynamo 記錄所有 PyTorch 操作並按順序儲存它們。 例如,它將 z = (x - y) ** 2 分成兩個組成操作,sub = l_x_ - l_y_z = sub ** 2

當我們說追蹤是線性的時,我們指的是沒有分支或任何控制流程。為了說明這一點,請考慮:

import torch

@torch.compile
def fn(x, n):
    y = x ** 2
    if n >= 0:
        return (n + 1) * y
    else:
        return y / n

x = torch.randn(200)
fn(x, 2)

當使用 TORCH_LOGS=graph_code 執行時,它會返回:

def forward(l_x_: torch.Tensor):
    # File: example.py:5, code: y = x ** 2
    y = l_x_ ** 2
    # File: example.py:7, code: return (n + 1) * y
    mul = 3 * y
    return (mul,)

我們看到 Dynamo 完全從追蹤中移除了 if 語句,只記錄了使用輸入執行的操作。

因此,應該清楚的是,函數的追蹤取決於輸入。特別是,這意味著追蹤不是在我們編寫 @torch.compile 時生成的,而是在我們使用實際參數執行函數 fn(x, 2) 時生成的。

另一個值得注意的有趣的事情是,Dynamo 移除了函數的第二個參數。相反,它將其視為常數,並在圖中記錄了操作 n + 1 的結果。這是 Dynamo 的另一個特性:Dynamo 會將任何非張量值視為常數……除了整數。現在讓我們看看整數有什麼特別之處。

Dynamo 的最後一個定義屬性是它知道如何處理動態形狀。符號形狀是指 Dynamo 追蹤形狀(更廣泛地說是整數)的能力,而不是將它們保留為常數。這允許避免重新編譯並部署適用於生產中任何大小的通用模型。動態形狀出現的主要例子是批量大小,我們可能會使用固定的批量大小訓練模型,然後執行任意批量大小的推論,或者處理文字或音訊時遇到的可變序列長度。

我們可以通過再次執行上面的例子來看到這一點:

import torch

@torch.compile
def fn(x, n):
    y = x ** 2
    if n >= 0:
        return (n + 1) * y
    else:
        return y / n

x = torch.randn(200)
fn(x, 2)
fn(x, 3)
fn(x, -2)

在這種情況下,TORCH_LOGS=graph_code 生成了另外兩個圖:

# Graph for n==2 omitted

def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
    # File: a.py:5, code: y = x ** 2
    y = l_x_ ** 2

    # File: a.py:7, code: return (n + 1) * y
    add = l_n_ + 1
    mul = add * y
    return (mul,)
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
    # File: a.py:5, code: y = x ** 2
    y = l_x_ ** 2

    # File: a.py:9, code: return y / n
    truediv = y / l_n_
    return (truediv,)

Dynamo 檢測到一個整數在第一次調用後更改了其值並開始追蹤它。我們看到這些圖是通用的,並且通過 SymInt 類型的物件符號化地追蹤變數 n

如果在這些調用之後,我們調用 fn(x, 4),Dynamo 不會重新編譯,而是重複使用已經追蹤的圖。

總結一下:1. Dynamo 是一個 Python 追蹤器 2. 給定一些輸入,它會返回一個 FX 圖,其中包含已執行的 PyTorch 函數 3. 如果它檢測到整數在調用之間發生變化,它也可以追蹤整數 4. 它會特殊化 (specialize) 任何不是張量或標量的值

當然,Dynamo 還有很多事情要做,例如弄清楚何時需要重新追蹤、重寫函數的位元組碼、實現圖中斷……為了使介紹簡短,我們將在後續章節中逐步討論所有這些內容。

PEP 523:向 CPython 添加 frame 評估 API

現在想像一下,我們被賦予了實現 Dynamo 的任務。我們從哪裡開始呢?對我們來說非常方便的是,PEP 523 與 Python 3.6 一起發布。這個 PEP 旨在允許第三方為 Python 建立 JIT 編譯器。讓我們看看如何。

CPython 註解:CPython 在內部實現為 堆疊機器。Python 程式被編譯成 位元組碼,然後由這個直譯器執行。要了解更多關於這些位元組碼的信息,請參閱標準庫中的 dis 模組。另請參閱 開發者文件,以了解 CPython 直譯器的簡介。我們假設讀者熟悉堆疊機器的概念。

PEP 523 公開了一個 API,用戶可以在其中添加自定義的逐函數直譯器。然後,CPython 將使用此直譯器而不是它自己的來執行該函數。為了能夠執行該函數,在進入時,CPython 為自定義直譯器提供以下內容:- 函數的位元組碼 - 函數參數的值(即,局部變數)及其名稱 - 全域變數的值及其名稱 - 內建函數,例如 absprint

你可以在這裡看到所有的欄位。2

總之,CPython 提供使用者解譯器所有執行函式所需資訊。3

透過此 API,我們可以實作一個追蹤器,方法是實作一個解譯器,執行程式碼並在圖中記錄在此執行期間發生的所有 PyTorch 運算。這正是 Dynamo 所做的事情。

Dynamo 使用此 CPython API 來解析所有這些物件,並將它們打包到一個 Python 結構中。完成後... 它會從 C 回到 Python。除了這段與 CPython 通訊的程式碼之外,Dynamo 完全以 Python 實作。

應該很清楚的是,裝飾器 @torch.compile 的工作是安裝必要的 scaffolding,以便在呼叫函式時將 bytecode、args、全域變數等等傳遞給 Dynamo。再次強調,@torch.compile 實際上沒有編譯任何東西。

用 Python 實作 CPython

所以,我們又回到了 Python 的世界。我們有一個函式的 bytecode,以及執行它所需的所有上下文。特別是,我們已抵達 _convert_frame_assert。這是裝飾器 torch.compile 返回的函式!我們從 _dynamo.optimize 進入這個函式。裝飾器 torch.compile 只是 _dynamo.optimize 周圍的一個友善 API。

在開始實作 Python 解譯器之前,我們要定義一個 IR(中介表示法)。特別是,我們希望將所有區域變數和全域變數包裝在我們自己的內部類別中。這讓我們可以更好地追蹤這些物件,並將 Dynamo 視為相同方式處理的物件分組在一起。

內部類別結構的父類別是 VariableTracker,代表 Dynamo 理解的不同物件。例如,ListVariable 代表一個 list 物件,並在內部保存一個 VariableTrackers 清單VariableTracker 的另一個例子是 ConstantVariable。ConstantVariable 包裝了所有 Dynamo 認為是常數的物件。我們還有需要特別注意的物件的特殊子類別,例如 TensorVariable。所有這些內部類別都在 torch/_dynamo/variables 資料夾中定義。

Python 物件在 VariableBuilder._wrap 中被包裝到它們相應的 VariableTracker 類別中。這個函式只是一個非常長的 elif 鏈,試圖遞迴地將 Python 輸入模式比對到適當的 VariableTracker 類型。

偵錯提示。當我們從 dynamo 獲得意想不到的結果時,有時是由於 builder 造成的。如果 builder 的邏輯錯誤,有時 Dynamo 可能會將變數包裝在不正確的 VariableTracker 類型中,這可能會在以後造成問題。查看錯誤中出現的 VariableTracker 類型,以及在遇到 Dynamo 錯誤時拋出例外狀況的 VariableTracker 方法,是相當有用的。特別是,有時我們會發現一個物件被追蹤為 UserDefinedObjectVariable(這是 Dynamo 的 catch-all 類別),而它應該被追蹤為更具體的東西。在這些情況下,通常要歸咎於 SourceBuilder.__call__ 邏輯。

偵錯提示。當使用 TORCH_LOGS=dynamo 執行程式時,印出的 artifacts 之一是以下形式的行:

TRACE LOAD_GLOBAL y [TorchInGraphFunctionVariable(<built-in method any>), TensorVariable()]

這是原始程式的 bytecode 以及該點的堆疊狀態。這對於找出物件沒有被追蹤到正確的 VariableTracker 中非常有用。

好的,所以我們有一個追蹤器的 IR,現在我們只需要重新實作 CPython 的堆疊機器。這是由 InstructorTranslatorBasesymbolic_convert.py 中實作的。

InstructionTranslatorBase 大約有 200 種方法,實作了幾乎所有 Python bytecode。舉例來說,我們可以看到 BUILD_LIST 的實作

def BUILD_LIST(self, inst):
    items = self.popn(inst.argval)
    self.push(ListVariable(items, mutation_type=ValueMutationNew()))

這是由諸如 l = [2, 3, 4] 之類的結構產生的 bytecode。在這種情況下,由於有三個元素,因此產生的 bytecode 為 BUILD_LIST 3。這意味著我們從堆疊的頂部彈出頂部的 3 個元素,並將由這三個元素組成的新清單物件推到堆疊的頂部。

產生輸出圖

有了以符號方式執行 Python 程式碼的方法,我們就可以提取在給定一些輸入的情況下,在程式的符號執行期間發生的 PyTorch 運算。這在 Dynamo 中是透過 OutputGraph 物件實作的。OutputGraph 物件 繫結到一個 `InstructionTranslator 物件,它追蹤建立 FX 圖所需的所有資料,該圖將由 Dynamo 返回。

FX 圖的所有輸入和中間元素都是 fx.Nodes。在 Dynamo 中,fx.Nodes 被包裝在 fx.Proxys 中。fx.Proxys 用於建立 FX 圖。特別是,它們將對其執行的每個 PyTorch 運算記錄到圖中。您可以透過呼叫 create_proxy 來建立一個要新增到圖中的新運算。然後,我們可以透過函式 wrap_fx_proxy 將其新增到圖中。

圖 (graph) 儲存對張量 (tensor) 的操作… 以及對符號整數 (symbolic integer) 的操作。我們稍後會討論符號整數,但首先我們要討論 Dynamo 如何解決一個相當重要的正確性問題。

使 Dynamo 可靠:防護 (Guards)

到目前為止,我們已經有一種可以完全忽略控制流程 (control flow) 來追蹤程式的方法。為此,我們重新實作了整個 CPython… 如果這聽起來有點過度,那是因為確實如此。torch.jit.trace 已經實作了這一點,而不需要所有這些機制,那麼問題是什麼呢?

torch.jit.trace 的問題,正如其文件中警告的那樣,僅在被追蹤的程式與資料無關時才有效。換句話說,僅當程式本身是線性的時才有效。這意味著編寫程式時不使用 if-else、for-while 迴圈、例外處理。更重要的是,我們使用的任何函式庫都不能使用任何控制流程!總而言之,在像 Python 這樣動態的語言中不使用控制流程,實際上是一個巨大的限制。

JAX 通過始終重新追蹤 (retrace) 並且在重新追蹤後快取圖來解決這個問題。另一方面,Dynamo 使用防護 (guards) 來避免每次都重新追蹤整個程式。

防護是一種假設(關於輸入的布林運算式),為了針對一組範例輸入專門化 (specialize) 一個 frame。只有當這些假設在新輸入上成立時,重用圖才是有效的。

例如,函數的任何常數輸入 (constant input),例如字串 (string),都會安裝一個防護,聲明該輸入應該是 str 類型,並且等於我們傳遞的字串。執行

import torch

@torch.compile
def fn(a, b):
    return a * len(b)

fn(torch.arange(10), "Hello")

使用 TORCH_LOGS=guards 列印(以及其他防護)

___check_type_id(L['b'], 94334122025024)
L['b'] == 'Hello'

這表示「區域變數 b 應該具有特定的類型(在本例中為 str,由常數 9433... 表示),並且其值應該是 'Hello'」。如果我們然後再次執行該函數,傳遞不同的引數

import torch

@torch.compile
def fn(a, b):
    return a * len(b)

fn(torch.arange(10), "Hello")
fn(torch.arange(10), "Hi")

我們可以通過執行 TORCH_LOGS=recompiles 來查看失敗的防護

Recompiling function fn in script.py:3
triggered by the following guard failure(s):
     - L['b'] == 'Hello'

函數的輸入被包裝在 builder 中 以及 在程式執行期間,防護會被累積。我們將在下一節中展示更多防護的範例,但首先讓我們討論來源 (sources)。

來源追蹤如何從進入當前 frame 時存在的原始區域或全域變數中重建變數。特別是,它追蹤原始的區域和全域物件以及它們包含的任何物件。在

def foo(x: Tensor, y: List[Tensor]):
    a = x * y[0]
    return a * x

xy 具有 LocalSource 作為它們的來源,並且 y[0] 具有 GetItemSource,它在內部儲存一個 LocalSource。另一方面,a 將沒有來源,因為它是一個僅存在於 fx 圖中的中間變數。

所有這些都在 torch/_dynamo/source.py 中定義。我們可以在以下範例中看到由 GetItemSource 產生的防護

import torch

@torch.compile
def fn(x, l):
    return x * len(l[0])

fn(torch.randn(8), ["Hi", "Hello"])

產生以下防護

___check_type_id(L['l'], 94439025877664)
len(L['l']) == 2
___check_type_id(L['l'][0], 94439025840192)
L['l'][0] == 'Hi'
___check_type_id(L['l'][1], 94439025840192)
L['l'][1] == 'Hello'

在這裡,我們看到由 GetItemSource[0][1])產生的程式碼包裝了 LocalSourceL['l'])。

至此,有了來源和防護,我們就可以實作一個快取系統,以避免每次都必須重新追蹤而導致的重新編譯。我們將在後續內容中更詳細地討論這個快取系統。

細心的讀者會注意到,這還沒有解釋為什麼我們需要對 Python 直譯器 (interpreter) 進行如此精細的控制,以至於需要重新實作它。我們展示的防護範例取決於輸入物件,所以我們仍然可以在執行函數之前計算這些防護。換句話說,我們可以在 torch.jit.trace 之上實作這個防護系統,並以更少的工作量獲得相同的功能… 進入符號形狀 (symbolic shapes)。

符號形狀

我們在介紹中討論的另一個重點是 Dynamo 知道如何追蹤整數。為了實作這一點,我們使用一個符號類別 torch.SymInt,它的作用類似於 int,但它記錄了在輸出 FX 圖中對它執行的所有操作。4 我們已經在介紹符號整數追蹤時看到了這個類別。

現在讓我們討論在 Dynamo 中定義符號形狀追蹤的三個屬性,以及如何實作它們。

預設靜態 (Static by default)

Dynamo 假設每個整數,無論是輸入還是張量的形狀,預設都是靜態的。換句話說,在函數的第一次執行中,不會追蹤任何整數。然後,僅當它檢測到整數或形狀在執行期間改變了值,它才會追蹤它並生成一個通用於該變數的圖。

我們已經在介紹中使用整數看到了這種行為。現在讓我們來看一個使用張量形狀的範例。

import torch

@torch.compile
def fn(a, b):
    return a.shape[0] * a * b

fn(torch.randn(4, 3), torch.randn(4, 3))
fn(torch.randn(8, 3), torch.randn(8, 3))

使用 TORCH_LOGS=graph_code 執行這個程式,我們看到這兩個呼叫被追蹤為

def forward(self, l_a_: torch.Tensor, l_b_: torch.Tensor):
    mul = 4 * l_a_
    mul_1 = mul * l_b_
    return (mul_1,)

def forward(self, s0: torch.SymInt, l_a_: torch.Tensor, l_b_: torch.Tensor):
    size = l_a_.size()
    getitem = size[0]
    mul = getitem * l_a_
    mul_1 = mul * l_b_
    return (mul_1,)

在第一個圖中,形狀被追蹤為常數,但是一旦它改變,它就會使用 SymInt 來符號地追蹤它。通常,查看中間值形狀的更簡單方法是使用 TORCH_LOGS=graph_sizes 執行程式

TRACED GRAPH TENSOR SIZES
===== __compiled_fn_1 =====
l_a_: (s0, 3)
l_a_ (concrete): (8, 3)
l_b_: (s0, 3)
l_b_ (concrete): (8, 3)
mul: (s0, 3)
mul (concrete): (8, 3)
mul_1: (s0, 3)
mul_1 (concrete): (8, 3)

在這裡我們可以發現兩個張量引數的第一個維度是動態的,因為它由 s0 變數表示。

我們可以通過執行 TORCH_LOGS=guards 來找到 Dynamo 如何實作這一點

# Guards first call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])

# Guards second call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])

L['b'].size()[0] == L['a'].size()[0]
2 <= L['a'].size()[0]

我們看到在第一次呼叫時,防護檢查張量是否具有一些固定的尺寸和步幅 (strides)。這些防護在第二次執行時失敗,所以它重新追蹤。由於是 int 防護失敗,因此在第二次迭代中,它會符號地追蹤這個 int,並在這個更通用的 kernel 上安裝更通用的防護。

編譯效能提示。如果您知道一個維度的大小會變動,您可以在呼叫 torch.compile 之前,呼叫 torch._dynamo.mark_dynamic 將其標記為動態。 這將避免第一次使用靜態形狀進行編譯。 還有其他有用的實用函數,例如 maybe_mark_dynamicmark_static。 您也可以透過呼叫 torch.compile(dynamic=True) 來追蹤所有整數和形狀。 這主要用於除錯目的。

0、1 永遠都會被特化

無論我們是否將一個維度標記為動態,如果我們傳入一個輸入,其中該維度為 0 或 1,Dynamo 都會將其追蹤為非動態,並且會為其產生一個特定的圖。 這就是為什麼在上面的例子中,我們會找到 2 <= L['a'].size()[0] 形式的守衛。

做出這個選擇有幾個原因。 有兩個特別重要 - 張量只有在其任何維度為零時才是空的 - 張量只有在其中一個步幅為一時才能是連續的

這個策略決策不適用於普通的 Python 整數; 如果我們認為應該動態編譯 Python 整數,我們預設不會特化它們; 相反,它是否被特化取決於它的使用方式。

鴨子定型 (Duck shaping)

Dynamo 執行我們所謂的「鴨子定型」。 如果兩個動態整數在追蹤時具有相同的值,我們會假設它們相等並對其進行守衛。 實際上,這意味著我們不是在上面的例子中有兩個符號 s0s1,而是將它們統一為 s0 並具有守衛 L['b'].size()[0] == L['a'].size()[0]。 這使得能夠在編譯器中執行融合,同時能夠產生足夠通用的核心。

符號整數上的守衛

我們現在了解了符號形狀如何在較高的層面上實現以及它們所擁有的屬性。 現在,為什麼符號形狀迫使我們走上如此棘手的道路,以至於需要控制 CPython 解釋器? 考慮以下範例

import torch

@torch.compile(dynamic=True)
def fn(a):
    if a.shape[0] * 2 < 16:
        return a
    else:
        return a + 1

fn(torch.randn(8))

此程式碼具有 2*L['a'].size()[0] >= 16 形式的守衛。 就函數的輸入而言,這是一個重要的守衛,但它是在程式執行過程中註冊的。 甚至更重要的是,在我們看到條件基於 SymNodeVariable 引數的 if 陳述式之前,我們無法知道是否需要此守衛。 這種條件對於 torch.jit.trace 來說是不可見的,並且需要對 Python 程式碼進行深入分析。

除錯提示 使用 TORCH_LOGS=dynamo 執行此程式碼會告訴我們在哪裡新增了此守衛

eval 2*s0 >= 16 [guard added] at script.py:5 in fn (_dynamo/variables/tensor.py:812 in evaluate_expr)

在那裡放置一個中斷點並查看回溯對於理解守衛來自哪裡非常有用。

使 Dynamo 完整:圖形中斷

使用我們討論的所有工具,我們有一個追蹤器,可以追蹤張量和整數上的 PyTorch 操作,並且有一個快取系統,知道何時可以重複使用先前追蹤的圖形,以及何時需要重新追蹤。 所有這些都執行任意 Python 程式碼!

這只有一個小問題。 「執行任意 Python 程式碼」的說法可能有點過於概括。 Dynamo 實現了 Python 的很大一部分,但它是否實現了更複雜的部分,例如協程或非同步? 它是否實現了整個 Python 標準函式庫? NumPy 也有一個 Python API。 torch.compile 也理解 NumPy 嗎? 還有 Django? 5

Python 的生態系統非常龐大,其中很大一部分是用其他效能更高的語言(如 C++ 或 Rust)編寫的,它只是公開了 Python 繫結。 Dynamo 沒有希望追蹤透過用 C++ 實現的 Python 物件。 當追蹤器發現它不理解的操作時,它能做什麼?

機器學習追蹤器處理這個問題的常用方法是通知使用者他們被哪個操作阻塞,並完全放棄追蹤。 這會在 PyTorch 的情況下造成真正的可用性問題,因為它的使用者已經習慣了它給予他們的靈活性。 作為一個真實世界的例子,doctr_det_predictor 模型使用 NumPy 和 cv2 函式庫來 後處理模型的結果

這裡是另一個可以訪問 CPython 的有趣的地方。 Dynamo 沒有報錯,而是可以讓 CPython 運行有問題的程式碼! 為此,Dynamo 在追蹤時產生一個圖,其中包含有問題的程式碼之前的所有操作,以及一個包含之後的所有操作的圖。 6 然後,在執行時,它將委託給 CPython 來執行第一個圖,然後執行有問題的程式碼,然後執行第二個圖。 這種停止追蹤並產生多個圖的過程稱為圖形中斷

一個小小的坦白:我在整個介紹和第一部分都說謊了。 Dynamo 並不是產生一個圖,而是多個圖! 就所有實際用途而言,在第二個圖之後開始重新追蹤可以被認為是開始追蹤一個新函數。 圖形中斷後的新圖將有其自己的守衛、其新的本地變數集等等。

為了討論如何實作圖形中斷(graph breaks),我們需要先回顧 Dynamo 如何與 CPython 互動。透過 PEP 523,CPython 允許使用者使用自己的 frame 評估機制。我們之前沒有討論的是,CPython 也暴露了自己的 frame 評估機制供他人使用。 Dynamo 利用這一點,讓快速的 CPython 直譯器執行編譯後的程式碼。對於沒有圖形中斷的函式,呼叫該函式兩次且使用相同參數的程式的完整追蹤/執行過程如下所示:

  1. 在第一次呼叫函式時:

    1. Dynamo 將函式追蹤成 FX 圖形。

      1. FX 圖形由編譯器 (Inductor) 編譯成有效率的底層程式碼……但那是另一天的故事了。

    2. 它重寫了函式的 bytecode,使其僅僅呼叫編譯後的函式。

    3. 它將這個新的 bytecode 給 CPython 並要求它執行 [這裡]。

  2. 在第二次呼叫函式時:

    1. 它根據新的參數檢查第一次呼叫的 guards [這裡]。由於它們與之前的參數相同,因此通過。

    2. 它要求 CPython 執行與這些 guards 相關聯的 bytecode [這裡]。

這個過程本身看起來過於複雜。為什麼要產生新的 bytecode 並要求 CPython 執行它,而不是簡單地建立一個 C++ binding 到編譯後的函式並執行它? 嗯,這種模式允許我們實作圖形中斷! 由圖形中斷產生的 bytecode 具有以下結構:

  1. 執行第一個圖形的 bytecode。

  2. 將 stack 恢復到 CPython 執行第一個圖形後的狀態的 bytecode。 它也會重播對此時可見的 local 或 global 變數的任何修改。

  3. 導致 Dynamo 圖形中斷的 bytecode。

  4. 執行第二個圖形的 bytecode。

讓我們在一個簡單的例子中看看這個:

import torch

@torch.compile
def fn(a):
    b = a + 2
    print("Hi")
    return b + a

fn(torch.randn(4))

使用 TORCH_LOGS=bytecode 運行它會向我們顯示初始 bytecode 和修改後的 bytecode。

MODIFIED BYTECODE fn script.py line 3
 0 LOAD_GLOBAL              1 (__compiled_fn_0)
 2 LOAD_FAST                0 (a)
 4 CALL_FUNCTION            1
 6 STORE_FAST               3 (graph_out_0)
 8 LOAD_GLOBAL              0 (print)
10 LOAD_CONST               2 ('Hi')
12 LOAD_FAST                3 (graph_out_0)
14 LOAD_CONST               3 (0)
16 BINARY_SUBSCR
18 STORE_FAST               1 (b)

20 CALL_FUNCTION            1
22 LOAD_GLOBAL              2 (__resume_at_14_1)
24 ROT_TWO
26 LOAD_FAST                0 (a)
28 LOAD_FAST                1 (b)
30 CALL_FUNCTION            3
32 RETURN_VALUE

MODIFIED BYTECODE resume_in_fn script.py line 6
 0 LOAD_GLOBAL              1 (__compiled_fn_2)
 2 LOAD_FAST                2 (b)
 4 LOAD_FAST                1 (a)
 6 CALL_FUNCTION            2
 8 UNPACK_SEQUENCE          1
10 RETURN_VALUE

我們可以看到修改後的 bytecode 被分成兩個函式,fn,即原始函式,以及一個名為 resume_in_fn 的函式。 第二個函式是由 Dynamo 創建的函式,用於實作從圖形中斷開始執行的程式。 這通常被稱為continuation function(後續函式)。 這個後續函式只是用正確的參數呼叫第二個編譯後的函式。 初始函式的程式碼被重寫以實作我們之前描述的策略。

  • L0-4. 呼叫編譯後的函式 (a + 2)。

  • L6. 將其結果儲存在一個名為 graph_out_0 的 local 變數中。 graph_out_0 是一個 tuple。

  • L8-18. 將 stack 恢復到圖形中斷時的狀態。

  • L20. 執行導致圖形中斷的程式碼。

  • L22-32. 呼叫編譯後的後續函式 (a + b)。

Dynamo 中 stack 的程式碼生成委託給 VariableTracker 子類。 Dynamo 中的每個 VariableTracker 物件都有一個 reconstruct 方法,該方法生成必要的 bytecode 以在 stack 上建立它所代表的 python 物件。

除錯提示。 圖形中斷會損害效能,因此最好避免它們。 使用 TORCH_LOGS=graph_breaks 運行程式是找到我們的程式命中了多少個圖形中斷的好方法。 它返回的資訊採用 VariableTracker 物件的形式,因此上面的除錯提示有時也有助於弄清楚是什麼導致了該圖形中斷。

結論

Dynamo 是一個複雜的軟體。 一旦您開始實作 CPython 直譯器,您就知道您正在進行一場冒險。 也就是說,我們希望這篇文章有助於揭開它的一些神秘面紗。

Dynamo(主要)是用 Python 實作的。 我們留下了許多指向我們討論過的程式碼片段的連結。 我們希望閱讀這些程式碼片段並 grep 呼叫它們的位置,或在它們上放置中斷點並查看呼叫 stack 有助於理解程式碼庫的其餘部分。

當然,學習軟體如何運作的最佳方法是擴展它。 在這種情況下,最好的方法是查看 github 上未解決的 dynamo issues。 其中許多只需要對程式碼進行非常小的修改,一旦您找到需要進行這些修改的位置。

腳註

1

在文獻中,這被稱為 Directed Acyclical Graph (DAG)。

2

所有這些 binding 程式碼都位於 torch/csrc/dynamo/eval_frame.c 中。

3

在 CPython 的術語中,所有這些物件的集合稱為 frame

4

還有 SymBoolSymFloat 類別。 在撰寫本文時,後者使用不多。

5

有趣的是,它確實理解 NumPy 程式碼! 請查看 這篇網誌文章文件。 現在,這只是因為我們使用 PyTorch 重新實作了 NumPy 才有可能。 祝你好運在 PyTorch 中實作 Django…

6

假設只有一段有問題的程式碼。 如果有更多,Dynamo 可以將程式碼分成它需要的許多圖形。

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學課程

取得為初學者和進階開發者提供的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您問題的解答

檢視資源