Using the Minifier
We have a pretty convenient test case minifier with this interface:
def minifier(fail_f: fx.GraphModule, inps, module_fails):
    """
    Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Does 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
       tries replacing quarter of the graph, etc.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    ...
Concretely, it takes your FX graph and tries to minify it with the following 4 strategies (while checking that the resulting graph still returns True for module_fails), until it can't be minified any further.
Truncates suffix: Given an FX graph, it tries to remove some suffix from the graph. For example, given this:
def f(a):
    b = a * 2
    c = b + 3
    d = c / 4
    return d
it might try truncating the suffix and get:
def f(a):
    b = a * 2
    c = b + 3
    return c
It tries to do this in a binary-search-like manner, trying to remove the last 1/2, then 3/4, 1/4, then 7/8, 5/8, 3/8, and so on.
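To make that search order concrete, here is a small illustrative sketch of the order of suffix fractions described above (this is not the actual implementation, which operates on FX nodes):

from fractions import Fraction

def suffix_removal_order(max_depth=8):
    # Illustrative only: yield the fraction of the graph that the suffix strategy
    # would try to remove next, largest suffixes first at each granularity.
    depth = 2
    while depth <= max_depth:
        # depth=2 -> 1/2; depth=4 -> 3/4, 1/4; depth=8 -> 7/8, 5/8, 3/8, 1/8; ...
        for numerator in range(depth - 1, 0, -2):
            yield Fraction(numerator, depth)
        depth *= 2

print([str(f) for f in suffix_removal_order()])
# ['1/2', '3/4', '1/4', '7/8', '5/8', '3/8', '1/8']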
Delta Debugging: Of course, removing the suffix isn't always enough to minimize the graph. What if the failure is caused by the first instruction? So, we take an approach inspired by delta debugging - we try removing intermediate nodes of the graph. Unlike with suffixes, the removed nodes may still have dependencies on them, so instead of removing them entirely, we promote them to inputs. For example, given the example above:
def f(a):
    b = a * 2
    c = b + 3
    d = c / 4
    return d
we might remove an intermediate node (in this case, c):
def f(a, c):
    b = a * 2
    d = c / 4
    return d
Finally, there are 2 auxiliary strategies - eliminating dead code and removing unused inputs. These are fairly self-explanatory; a short before/after sketch follows.
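For instance, in the delta-debugged graph above, b no longer feeds the output, so dead code elimination can drop it; once it is gone, a no longer feeds anything, so unused input removal can drop it too. A purely illustrative before/after, continuing the toy example:

# Before: b is dead code, and once it is removed, a becomes an unused input.
def f(a, c):
    b = a * 2
    d = c / 4
    return d

# After dead code elimination and unused input removal:
def f(c):
    d = c / 4
    return d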
So, let's look at a toy example. Let's assume that our graph fails if it has a "multiply" in it. Let's create a failing graph:
import torch
import torch.fx as fx
from functorch.compile import minifier

def failing_f(x, y):
    y = torch.ops.aten.div(x, y)
    x = torch.ops.aten.add(x, 3)
    x = torch.ops.aten.mul(x, y)
    return torch.ops.aten.sub(x, y)

inps = [torch.randn(3), torch.randn(3)]

# module_fails: returns True if the graph still contains an aten.mul node
def pass_checker(fx_g, inps):
    return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))

min_f, inps = minifier(fx.symbolic_trace(failing_f), inps, pass_checker)
Started off with 7 nodes
###################
Current size: 7
###################
Strategy: Remove suffix
SUCCESS: Removed [4:7)
###################
Current size: 6
###################
Strategy: Delta Debugging
SUCCESS: Removed (0:4] - Went from 2 placeholders to 4
###################
Current size: 6
###################
Strategy: Remove unused inputs
SUCCESS: Went from 4 inputs to 2 inputs
###################
Current size: 4
###################
Strategy: Remove suffix
FAIL: Could not remove suffix
Strategy: Delta Debugging
FAIL: Could not remove prefix
inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
def forward(self, div, add):
    mul = torch.ops.aten.mul(add, div); add = div = None
    return (mul,)

f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
    for _ in range(5):
        f(*inps)
Voila! Our graph is now a minimal example that still fails.
Since the primary use case of this minifier (currently) is NVFuser repros, for convenience we print out a string that creates a self-contained repro to run the minified graph with NVFuser.
Note that in practice, we provide 2 main "graph checkers" - check_nvfuser_subprocess and check_nvfuser_correctness_subprocess. These are used to check for errors and for correctness (i.e. whether the results match eager), respectively. They can be used like this:
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
minifier(failing_graph, inps, check_nvfuser_subprocess)
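If neither of those checkers fits your case, module_fails can be any callable taking (fx_g, inps) and returning True when the failure still reproduces. A minimal sketch (the checker below is illustrative, not part of the functorch API):

# Illustrative custom checker: treat "raises any exception when run" as the failure.
def raises_exception(fx_g, inps):
    try:
        fx_g(*inps)
        return False  # ran cleanly, so this candidate graph no longer reproduces the bug
    except Exception:
        return True   # still fails, so the minifier can keep shrinking this graph

minified_graph, minified_inps = minifier(failing_graph, inps, raises_exception)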
However, assuming you're using AOTAutograd, there's another problem - how do you obtain the FX graph in the first place to pass to the minifier? One possible way is simply to use print_compile:
from functorch.compile import aot_function
from functorch.compile import print_compile
# Or...
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g

def foo(x):
    return x.cos().cos()

inp = torch.randn(3, requires_grad=True)
aot_function(foo, print_compile)(inp)
def forward(self, primals_1):
    cos = torch.ops.aten.cos(primals_1)
    cos_1 = torch.ops.aten.cos(cos)
    return [cos_1, primals_1, cos]

def forward(self, primals_1, cos, tangents_1):
    sin = torch.ops.aten.sin(cos); cos = None
    neg = torch.ops.aten.neg(sin); sin = None
    mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(primals_1); primals_1 = None
    neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None
    return [mul_1]

tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)
However, this doesn't provide the inputs, nor does it handle any tensor constants that might be saved in the graph. To resolve that, we have another "compiler" called debug_compile. It simply prints out a string that can be copy-pasted and run from another file. It leverages FX's to_folder feature to serialize the graph to disk, along with any constants. You can apply it to either fw_compiler to dump the forward graph or bw_compiler to dump the backward graph.
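For example, to dump the forward graph you would pass it as fw_compiler instead (a sketch assuming the foo and inp defined above; the actual run below uses the bw_compiler form):

# Dump the forward graph rather than the backward one
# (assumes foo and inp from the print_compile example above).
from functorch.compile import memory_efficient_fusion, debug_compile

memory_efficient_fusion(foo, fw_compiler=debug_compile)(inp)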
from functorch.compile import memory_efficient_fusion, debug_compile
memory_efficient_fusion(foo, bw_compiler=debug_compile)(inp)
##############################################################
# To minimize FX graph, copy and paste the below and run it #
##############################################################
import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()
with torch.jit.fuser("fuser2"):
    # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
    minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)
So, let's copy-paste it and see how it works - note that I made a couple of small modifications to run it on CPU and to use the earlier "graphs fail if they have a multiply in them" checker.
import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.ones(shape, dtype=dtype) for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule()
minifier(fx.symbolic_trace(mod), inps, pass_checker)
Started off with 10 nodes
###################
Current size: 10
###################
Strategy: Remove suffix
SUCCESS: Removed [6:10)
###################
Current size: 8
###################
Strategy: Delta Debugging
SUCCESS: Removed (0:4] - Went from 2 placeholders to 4
###################
Current size: 8
###################
Strategy: Remove unused inputs
SUCCESS: Went from 4 inputs to 3 inputs
###################
Current size: 7
###################
Strategy: Remove suffix
SUCCESS: Removed [4:7)
###################
Current size: 6
###################
Strategy: Remove unused inputs
SUCCESS: Went from 3 inputs to 2 inputs
###################
Current size: 5
###################
Strategy: Delta Debugging
SUCCESS: Removed (2:3] - Went from 2 placeholders to 3
###################
Current size: 5
###################
Strategy: Remove unused inputs
SUCCESS: Went from 3 inputs to 2 inputs
###################
Current size: 4
###################
Strategy: Remove suffix
FAIL: Could not remove suffix
Strategy: Delta Debugging
FAIL: Could not remove prefix
inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
def forward(self, tangents_1, neg):
    mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None
    return (mul,)

f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
    for _ in range(5):
        f(*inps)
(GraphModule(), [tensor([1., 1., 1.]), tensor([-0.5144, -0.5144, -0.5144])])
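As in the first toy example, minifier returns the minified GraphModule together with a matching set of inputs, so you can keep working with the repro directly. A small sketch (the variable names are illustrative):

# minifier returns (minified GraphModule, matching inputs).
min_f, min_inps = minifier(fx.symbolic_trace(mod), inps, pass_checker)

print(min_f.code)   # inspect the minified graph
min_f(*min_inps)    # re-run the minimal repro directly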
Hope this was helpful :)