• 文件 >
  • AOT Autograd - 如何使用與最佳化?
捷徑

AOT Autograd - 如何使用與最佳化?

Open In Colab

背景

在本教學課程中,我們將學習如何使用 AOT Autograd 加速深度學習模型的訓練。

作為背景知識,AOT Autograd 是一個協助開發人員加速 PyTorch 訓練的工具包。廣義來說,它具有兩個主要功能

  • AOT Autograd 會預先追蹤正向和反向圖。預先存在正向和反向圖有助於聯合圖最佳化,例如重新計算或激活檢查點。

  • AOT Autograd 提供簡單的機制,可透過深度學習編譯器(例如 NVFuser、NNC、TVM 和其他編譯器)編譯提取的正向和反向圖。

您將學到什麼?

在本教學課程中,我們將瞭解如何將 AOT Autograd 與後端編譯器結合使用,以加速 PyTorch 模型的訓練。更具體地說,您將學習

  • 如何使用 AOT Autograd?

  • AOT Autograd 如何使用後端編譯器執行運算融合?

  • AOT Autograd 如何啟用特定於訓練的最佳化,例如重新計算?

那麼,讓我們開始吧。

設定

讓我們設定一個簡單的模型。

import torch

def fn(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()
# Test that it works
a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]
ref = fn(a, b, c, d)
loss = ref.sum()
loss.backward()

使用 AOT Autograd

現在,讓我們使用 AOT Autograd 並查看提取的正向和反向圖。在內部,AOT 使用基於 __torch_dispatch__ 的追蹤機制來提取正向和反向圖,並將它們封裝在 torch.Fx GraphModule 容器中。請注意,AOT Autograd 追蹤與通常的 Fx 符號追蹤不同。AOT Autograd 使用 Fx GraphModule 僅僅是為了表示追蹤的圖(而不是為了追蹤)。

然後,AOT Autograd 會將這些正向和反向圖傳送到用戶提供的編譯器。因此,讓我們編寫一個僅列印圖的編譯器。

from functorch.compile import aot_function

# The compiler_fn is called after the forward and backward graphs are extracted.
# Here, we just print the code in the compiler_fn. Return of this function is a callable.
def compiler_fn(fx_module: torch.fx.GraphModule, _):
    print(fx_module.code)
    return fx_module

# Pass on the compiler_fn to the aot_function API
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

# Run the aot_print_fn once to trigger the compilation and print the graphs
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)
res.sum().backward()
assert torch.allclose(ref, res)
def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1 = torch.ops.aten.add(add, primals_3);  add = primals_3 = None
    add_2 = torch.ops.aten.add(add_1, primals_4);  add_1 = primals_4 = None
    cos = torch.ops.aten.cos(add_2)
    cos_1 = torch.ops.aten.cos(cos)
    return [cos_1, add_2, cos]
    



def forward(self, add_2, cos, tangents_1):
    sin = torch.ops.aten.sin(cos);  cos = None
    neg = torch.ops.aten.neg(sin);  sin = None
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(add_2);  add_2 = None
    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None
    return [mul_1, mul_1, mul_1, mul_1]
    

上面的程式碼列印了正向和反向圖的 Fx 圖。您可以看到,除了正向傳遞的原始輸入之外,正向圖還輸出了一些額外的張量。這些張量會被保存下來,以便在反向傳遞中用於梯度計算。我們稍後在討論重新計算時會回過頭來討論這些問題。

運算器融合

現在我們已經瞭解瞭如何使用 AOT Autograd 列印正向和反向圖,讓我們使用 AOT Autograd 來使用一些實際的深度學習編譯器。在本教學課程中,我們使用 PyTorch 神經網路編譯器 (NNC) 為 CPU 裝置執行逐點運算器融合。對於 CUDA 裝置,合適的替代方案是 NvFuser。因此,讓我們使用 NNC

# AOT Autograd has a suite of already integrated backends. Lets import the NNC compiler backend - ts_compile
from functorch.compile import ts_compile

# Lets compile the forward and backward through ts_compile.
aot_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)

# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs

res = aot_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)

讓我們對原始函數和 AOT Autograd + NNC 編譯後的函數進行基準測試。

# Lets write a function to benchmark the forward and backward pass
import time
import statistics

def bench(fn, args, prefix):
    warmup = 10
    iterations = 100

    for _ in range(warmup):
        ref = fn(*args)
        ref.sum().backward()
    
    fw_latencies = []
    bw_latencies = []
    for _ in range(iterations):
        for arg in args:
            arg.grad = None

        fw_begin = time.perf_counter()
        ref = fn(*args)
        fw_end = time.perf_counter()

        loss = ref.sum() 

        bw_begin = time.perf_counter()
        loss.backward()
        bw_end = time.perf_counter()

        fw_latencies.append(fw_end - fw_begin)
        bw_latencies.append(bw_end - bw_begin)
    
    avg_fw_latency = statistics.mean(fw_latencies) * 10**6
    avg_bw_latency = statistics.mean(bw_latencies) * 10**6
    print(prefix, "Fwd = " + str(avg_fw_latency) + " us", "Bwd = " + str(avg_bw_latency) + " us", sep=', ')
large_inputs = [torch.randn(1024, 2048, requires_grad=True) for _ in range(4)]

# Benchmark the Eager and AOT Autograd functions
bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
Eager, Fwd = 982.6959593920038 us, Bwd = 1899.7003795811906 us
AOT, Fwd = 734.2723174951971 us, Bwd = 831.1696897726506 us

在 NNC 的幫助下,AOT Autograd 加速了正向和反向傳遞。如果我們查看之前列印的圖表,所有運算器都是逐點的。逐點運算器受記憶體頻寬限制,因此可以從運算器融合中受益。仔細觀察這些數字,反向傳遞的速度提升更大。這是因為正向傳遞必須輸出一些中間張量,以便在反向傳遞中計算梯度,這使得它無法節省一些記憶體讀寫操作。但是,反向圖中不存在這樣的限制。

重新計算(又稱激活檢查點)

重新計算(通常稱為激活檢查點)是一種技術,在這種技術中,我們不是保存一些激活以供反向使用,而是在反向傳遞**期間**重新計算它們。重新計算可以節省記憶體,但會產生效能開銷。

但是,在存在融合編譯器的情況下,我們可以做得更好。我們可以重新計算易於融合的運算器以節省記憶體,然後依靠融合編譯器來融合重新計算的運算器。這減少了記憶體和執行時間。有關更多詳細資訊,請參閱此討論文章

在這裡,我們將 AOT Autograd 與 NNC 結合使用,以執行類似類型的重新計算。在 __torch_dispatch__ 追蹤結束時,AOT Autograd 有一個正向圖和一個聯合正向-反向圖。然後,AOT Autograd 使用一個分割器來隔離正向圖和反向圖。在上面的示例中,我們使用了一個預設的分割器。在本實驗中,我們將使用另一個名為 min_cut_rematerialization_partition 的分割器來執行更智能的融合感知重新計算。分割器是可配置的,開發人員可以編寫自己的分割器並將其插入 AOT Autograd。

from functorch.compile import min_cut_rematerialization_partition

# Zero out the gradients so we can do a comparison later
a.grad, b.grad, c.grad, d.grad = (None,) * 4

# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.
# This will show us how the recomputation has modified the graph.
aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)
res = aot_fn(a, b, c, d).sum().backward()
def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1 = torch.ops.aten.add(add, primals_3);  add = primals_3 = None
    add_2 = torch.ops.aten.add(add_1, primals_4);  add_1 = primals_4 = None
    cos = torch.ops.aten.cos(add_2)
    cos_1 = torch.ops.aten.cos(cos);  cos = None
    return [cos_1, add_2]
    



def forward(self, add_2, tangents_1):
    cos = torch.ops.aten.cos(add_2)
    sin = torch.ops.aten.sin(cos);  cos = None
    neg = torch.ops.aten.neg(sin);  sin = None
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(add_2);  add_2 = None
    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None
    return [mul_1, mul_1, mul_1, mul_1]
    

我們可以看到,與預設分割器相比,正向傳遞現在輸出的張量更少,並且在反向傳遞中重新計算了一些運算。讓我們現在嘗試使用 NNC 編譯器來執行運算器融合(請注意,我們還有一個包裝函數 - memory_efficient_fusion,它在內部使用 min_cut_rematerialization_partition 和 Torchscript 編譯器來達到與以下程式碼相同的效果)。

# Lets set up the partitioner and NNC compiler.
aot_recompute_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=min_cut_rematerialization_partition)

# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs

res = aot_recompute_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)

最後,讓我們對不同的函數進行基準測試

bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
bench(aot_recompute_nnc_fn, large_inputs, "AOT_Recomp")
Eager, Fwd = 740.7676504226401 us, Bwd = 1560.5240693548694 us
AOT, Fwd = 713.8530415249988 us, Bwd = 909.1200679540634 us
AOT_Recomp, Fwd = 712.2249767417088 us, Bwd = 791.4606417762116 us

我們觀察到,正向和反向延遲都比預設分割器有所改進(並且比 eager 好得多)。正向傳遞中輸出較少,反向傳遞中輸入較少,以及融合,允許更好地利用記憶體頻寬,從而進一步提高速度。

實際用法

對於 CUDA 裝置上的實際用法,我們將 AOTAutograd 封裝在一個方便的包裝器中 - memory_efficient_fusion。在 GPU 上使用它進行融合!

from functorch.compile import memory_efficient_fusion

文件

存取 PyTorch 的完整開發者文件

查看文件

教學課程

取得針對初學者和進階開發者的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源