
Using the Minifier

We have a pretty convenient test case minifier with the following interface:

def minifier(fail_f: fx.GraphModule, inps, module_fails):
    """
    Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Does 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
        tries replacing quarter of the graph, etc.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    ...

Specifically, it takes your FX graph and tries to minify it with the following 4 strategies (while checking that the resulting graph still returns True for module_fails), until it can't minify it anymore.
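Before walking through each strategy, here is a minimal sketch of the overall fixed-point loop (the strategy functions here are hypothetical stand-ins, not functorch's actual internals): each strategy proposes a smaller candidate, which is accepted only if module_fails still returns True on it.

def minify_loop(fx_g, inps, module_fails, strategies):
    # Keep applying strategies until none of them can shrink the graph further.
    changed = True
    while changed:
        changed = False
        for strategy in strategies:
            candidate = strategy(fx_g, inps)  # hypothetical: returns (new_g, new_inps) or None
            if candidate is not None and module_fails(*candidate):
                fx_g, inps = candidate  # still fails, so accept the smaller graph
                changed = True
    return fx_g, inps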

  1. Truncates Suffix: Given an FX graph, it tries to remove some suffix from the graph. For example, given this:

def f(a):
    b = a * 2
    c = b + 3
    d = c / 4
    return d

it might try truncating the suffix and get

def f(a):
    b = a * 2
    c = b + 3
    return c

It tries this in a binary-search manner, trying to remove the last 1/2, then 3/4, 1/4, then 7/8, 5/8, 3/8...
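As a rough illustration of that search order, here is a small sketch (a hypothetical helper, not the actual functorch code) that yields how many leading nodes to keep, at successively finer granularities:

def suffix_truncation_candidates(num_nodes, max_granularity=8):
    # Granularity 2 keeps the first 1/2; granularity 4 keeps 1/4 then 3/4;
    # granularity 8 keeps 1/8, 3/8, 5/8, 7/8; and so on.
    granularity = 2
    while granularity <= max_granularity:
        for numerator in range(1, granularity, 2):  # odd numerators only
            yield num_nodes * numerator // granularity  # number of nodes to keep
        granularity *= 2

print(list(suffix_truncation_candidates(16)))  # [8, 4, 12, 2, 6, 10, 14]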

  2. Delta Debugging: Of course, removing the suffix isn't always sufficient to minify a graph. What if the error is caused by the first instruction? So, we take an approach inspired by delta debugging - we try removing intermediate nodes of the graph. Unlike with suffixes, there are still dependencies on the removed nodes. So, instead of removing them entirely, we promote them to inputs (see the torch.fx sketch after this list). For example, given the above example:

def f(a):
    b = a * 2
    c = b + 3
    d = c / 4
    return d

we might remove a middle node (say, c, in this case):

def f(a, c):
    b = a * 2
    d = c / 4
    return d

Finally, there are 2 auxiliary strategies - eliminating dead code and removing unused inputs. These are pretty self-explanatory.
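To make the "promote to inputs" step concrete, here is a minimal sketch using plain torch.fx APIs (illustrative only, not the minifier's actual internals): all uses of an intermediate node are redirected to a fresh placeholder, and the original node is then erased.

import torch.fx as fx

def promote_to_input(gm: fx.GraphModule, node_name: str) -> fx.GraphModule:
    graph = gm.graph
    node = next(n for n in graph.nodes if n.name == node_name)
    # Insert a new placeholder at the front of the graph to stand in for the node.
    with graph.inserting_before(next(iter(graph.nodes))):
        new_input = graph.placeholder(node_name + "_1")
    node.replace_all_uses_with(new_input)
    graph.erase_node(node)
    graph.lint()
    gm.recompile()
    return gm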

So, let's take a look at a toy example. Let's pretend that our graph fails if it has a "multiply" in it. Let's create a failing graph.

import torch
import torch.fx as fx
from functorch.compile import minifier

def failing_f(x, y):
    y = torch.ops.aten.div(x, y)
    x = torch.ops.aten.add(x, 3)
    x = torch.ops.aten.mul(x, y)
    return torch.ops.aten.sub(x, y)

inps = [torch.randn(3), torch.randn(3)]

def pass_checker(fx_g, inps):
    return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))

min_f, inps = minifier(fx.symbolic_trace(failing_f), inps, pass_checker)
[W OperatorEntry.cpp:133] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::multiply.Tensor(Tensor self, Tensor other) -> (Tensor)
    registered at aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: FuncTorchBatched
  previous kernel: registered at aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:10338
       new kernel: registered at /fsx/users/chilli/work/functorch/functorch/csrc/BatchRulesDecompositions.cpp:108 (function registerKernel)
Started off with 7 nodes
###################
Current size: 7
###################
Strategy: Remove suffix

SUCCESS: Removed [4:7)

###################
Current size: 6
###################
Strategy: Delta Debugging
SUCCESS: Removed (0:4] - Went from 2 placeholders to 4

###################
Current size: 6
###################
Strategy: Remove unused inputs
SUCCESS: Went from 4 inputs to 2 inputs

###################
Current size: 4
###################
Strategy: Remove suffix
FAIL: Could not remove suffix
Strategy: Delta Debugging
FAIL: Could not remove prefix

inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]



def forward(self, div, add):
    mul = torch.ops.aten.mul(add, div);  add = div = None
    return (mul,)
    
f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
  for _ in range(5):
    f(*inps)

Voila! Our graph is now a minimal example that still fails.

Since the primary use case of this minifier (currently) is for NVFuser repros, we print out, for convenience, a string that creates a self-contained repro to run the minified graph with NVFuser.

Note that in practice, we provide 2 main "graph checkers" - check_nvfuser_subprocess and check_nvfuser_correctness_subprocess. These are used to check for errors and correctness (i.e. do the results match eager), respectively. These can be used like:

from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
minifier(failing_graph, inps, check_nvfuser_subprocess)

However, assuming you're using AOTAutograd, there's another problem - how do you obtain the FX graph in the first place to pass to the minifier? One possible way is simply to use print_compile.

from functorch.compile import aot_function

from functorch.compile import print_compile
# Or...
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g

def foo(x):
    return x.cos().cos()
inp = torch.randn(3, requires_grad=True)
aot_function(foo, print_compile)(inp)
def forward(self, primals_1):
    cos = torch.ops.aten.cos(primals_1)
    cos_1 = torch.ops.aten.cos(cos)
    return [cos_1, primals_1, cos]
    



def forward(self, primals_1, cos, tangents_1):
    sin = torch.ops.aten.sin(cos);  cos = None
    neg = torch.ops.aten.neg(sin);  sin = None
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(primals_1);  primals_1 = None
    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None
    return [mul_1]
    
tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)

However, this doesn't provide the inputs, nor does it handle any tensor constants that might be saved in the graph. To resolve this, we have another "compiler" called debug_compile. It simply prints out a string that can be copy-pasted and run from another file. It leverages FX's to_folder feature to serialize the graph to disk, along with any constants.
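For reference, to_folder is a standard torch.fx feature; a minimal sketch of the mechanism the generated repro relies on (the folder name "foo" and module name "FxModule" here simply mirror what the printed repro imports):

import torch
import torch.fx as fx

def g(x):
    return x.cos() + 1

gm = fx.symbolic_trace(g)
# Writes the module's generated code (plus any tensor constants as files)
# into the folder, so another script can do `from foo import FxModule`.
gm.to_folder("foo", module_name="FxModule")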

You can apply it to either fw_compiler to dump the forward graph, or bw_compiler to dump the backwards graph.

from functorch.compile import memory_efficient_fusion, debug_compile

memory_efficient_fusion(foo, bw_compiler=debug_compile)(inp)
##############################################################
# To minimize FX graph, copy and paste the below and run it  #
##############################################################

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()

with torch.jit.fuser("fuser2"):
  # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
  minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)

So, let's copy-paste it and see how it works - note that I made a couple of small modifications to run on CPU and to use the previous "graph fails if it has a multiply in it" checker.

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.ones(shape, dtype=dtype) for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule()

minifier(fx.symbolic_trace(mod), inps, pass_checker)
Started off with 10 nodes
###################
Current size: 10
###################
Strategy: Remove suffix

SUCCESS: Removed [6:10)

###################
Current size: 8
###################
Strategy: Delta Debugging
SUCCESS: Removed (0:4] - Went from 2 placeholders to 4

###################
Current size: 8
###################
Strategy: Remove unused inputs
SUCCESS: Went from 4 inputs to 3 inputs

###################
Current size: 7
###################
Strategy: Remove suffix

SUCCESS: Removed [4:7)

###################
Current size: 6
###################
Strategy: Remove unused inputs
SUCCESS: Went from 3 inputs to 2 inputs

###################
Current size: 5
###################
Strategy: Delta Debugging
SUCCESS: Removed (2:3] - Went from 2 placeholders to 3

###################
Current size: 5
###################
Strategy: Remove unused inputs
SUCCESS: Went from 3 inputs to 2 inputs

###################
Current size: 4
###################
Strategy: Remove suffix
FAIL: Could not remove suffix
Strategy: Delta Debugging
FAIL: Could not remove prefix

inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]



def forward(self, tangents_1, neg):
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    return (mul,)
    
f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
  for _ in range(5):
    f(*inps)
(GraphModule(), [tensor([1., 1., 1.]), tensor([-0.5144, -0.5144, -0.5144])])

Hope that was useful :)
