捷徑

常見問題

作者: Mark Saroufim

torch.compile 支援訓練嗎?

torch.compile 支援訓練,使用 AOTAutograd 來捕捉反向傳播

  1. .forward() 圖形和 optimizer.step() 由 TorchDynamo 的 Python evalframe 前端捕捉。

  2. 對於 torchdynamo 捕捉的每個 .forward() 段,它使用 AOTAutograd 來產生一個反向圖形段。

  3. 每個正向和反向圖形對(可選)經過最小切割分割,以在正向和反向之間儲存最小狀態。

  4. 正向和反向對封裝在 autograd.function 模組中。

  5. 使用者程式碼呼叫 .backward() 仍然會觸發 eager 的 autograd 引擎,該引擎將每個編譯的反向圖形作為一個運算執行,同時也執行任何未編譯的 eager 運算的 .backward() 函數。

您是否支援分散式程式碼?

torch.compile 支援 DistributedDataParallel (DDP)。 正在考慮支援其他分散式訓練函式庫。

使用 Dynamo 處理分散式程式碼之所以具有挑戰性,主要原因在於 AOTAutograd 會展開前向傳播和反向傳播,並提供 2 個圖形供後端進行優化。對於分散式程式碼而言,這是一個問題,因為我們理想情況下希望將通訊操作與計算操作重疊。Eager PyTorch 以不同的方式實現這一點,針對 DDP/FSDP 使用自動微分鉤子、模組鉤子以及修改/突變模組狀態。如果直接應用 Dynamo,由於 AOTAutograd 編譯函數與分派器鉤子的互動方式,在反向傳播期間應在操作後立即運行的鉤子可能會延遲到整個反向傳播操作的編譯區域之後。

使用 Dynamo 優化 DDP 的基本策略已在 distributed.py 中概述,其主要思想是在 DDP bucket 邊界上進行圖形分割。

當 DDP 中的每個節點需要與其他節點同步其權重時,它會將其梯度和參數組織成 bucket,從而減少通訊時間,並允許節點將其梯度的一部分廣播到其他等待節點。

分散式程式碼中的圖形分割意味著您可以期望 Dynamo 及其後端優化分散式程式的計算開銷,但不能優化其通訊開銷。如果減少的圖形大小剝奪了編譯器融合的機會,則圖形分割可能會干擾編譯加速。但是,隨著圖形大小的增加,收益會遞減,因為目前大多數計算優化都是本地融合。因此,在實踐中,這種方法可能就足夠了。

我還需要匯出完整的圖形嗎?

對於絕大多數模型,您可能不需要,並且可以按原樣使用 torch.compile(),但在少數情況下,完整圖形是必要的,您可以通過簡單地運行 torch.compile(..., fullgraph=True) 來確保得到完整的圖形。這些情況包括:

  • 大規模訓練運行,例如需要流水線並行性和其他進階分片策略的 $250K+ 訓練。

  • 推理優化器,例如 TensorRTAITemplate,它們比訓練優化器更積極地依賴融合。

  • 行動裝置上的訓練或推理。

未來的工作將包括將通訊操作追蹤到圖形中,協調這些操作與計算優化,以及優化通訊操作。

為什麼我的程式碼崩潰了?

如果您的程式碼在沒有 torch.compile 的情況下運行良好,但在啟用它的情況下開始崩潰,那麼最重要的第一步是找出您的故障發生在堆疊的哪個部分。要對其進行故障排除,請按照以下步驟操作,並且只有在前一個步驟成功後才嘗試下一步。

  1. torch.compile(..., backend="eager"),它僅運行 TorchDynamo 前向圖形捕獲,然後使用 PyTorch 運行捕獲的圖形。如果這失敗了,那麼 TorchDynamo 就有問題。

  2. torch.compile(..., backend="aot_eager"),它運行 TorchDynamo 來捕獲前向圖形,然後運行 AOTAutograd 來追蹤反向圖形,而無需任何額外的後端編譯器步驟。然後,PyTorch eager 將用於運行前向和反向圖形。如果這失敗了,那麼 AOTAutograd 就有問題。

  3. torch.compile(..., backend="inductor"),它運行 TorchDynamo 來捕獲前向圖形,然後運行 AOTAutograd 來追蹤帶有 TorchInductor 編譯器的反向圖形。如果這失敗了,那麼 TorchInductor 就有問題。

為什麼編譯速度慢?

  • Dynamo 編譯 – TorchDynamo 具有內建的統計函數,用於收集和顯示每個編譯階段所花費的時間。可以通過在執行 torch._dynamo 之後調用 torch._dynamo.utils.compile_times() 來存取這些統計資訊。默認情況下,這會返回一個字串表示,其中包含每個 TorchDynamo 函數按名稱花費的編譯時間。

  • Inductor 編譯 – TorchInductor 具有內建的統計和追蹤函數,用於顯示每個編譯階段、輸出程式碼、輸出圖形可視化和 IR 轉儲所花費的時間。env TORCH_COMPILE_DEBUG=1 python repro.py。這是一個調試工具,旨在更容易地調試/理解 TorchInductor 的內部結構,其輸出看起來像 這樣。可以通過 torch._inductor.config.trace.* 啟用/禁用該調試追蹤中的每個檔案。預設情況下,配置文件和圖表都被禁用,因為它們的生成成本很高。有關更多示例,請參見 示例調試目錄輸出

  • 過度重新編譯 當 TorchDynamo 編譯函數(或其中一部分)時,它會對局部變數和全域變數做出某些假設,以便允許編譯器優化,並將這些假設表示為在運行時檢查特定值的 guard。如果任何這些 guard 失敗,Dynamo 將重新編譯該函數(或部分),最多 torch._dynamo.config.cache_size_limit 次。如果您的程式達到快取限制,您首先需要確定哪個 guard 失敗以及程式的哪個部分觸發了它。重新編譯分析器 自動執行將 TorchDynamo 的快取限制設定為 1 並在觀察模式 'compiler' 下運行程式的過程,該程式記錄任何 guard 失敗的原因。您應該確保運行程式的時間至少與您遇到問題時運行的時間(迭代次數)一樣長,並且分析器會在此持續時間內累積統計資訊。

為什麼您在生產環境中重新編譯?

在某些情況下,您可能不希望程式在預熱後出現意外編譯。例如,如果您在延遲關鍵應用程式中提供生產流量。為此,TorchDynamo 提供了一種替代模式,其中使用先前的編譯圖形,但不生成新的編譯圖形。

frozen_toy_example = dynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))

您如何加速我的程式碼?

加速 PyTorch 程式碼主要有 3 種方法:

  1. 通過垂直融合進行核心融合,垂直融合融合順序操作以避免過多的讀/寫。例如,融合 2 個後續的餘弦意味著您可以進行 1 次讀取 1 次寫入,而不是 2 次讀取 2 次寫入。水平融合:最簡單的例子是批處理,其中單個矩陣與一批示例相乘,但更常見的情況是分組 GEMM,其中一組矩陣乘法一起排程。

  2. 亂序執行:編譯器的一般優化,通過查看圖形中的確切資料依賴關係,我們可以確定執行節點的最佳時間以及可以重複使用哪些緩衝區。

  3. 自動工作放置:類似於亂序執行點,但通過將圖形的節點與物理硬體或記憶體等資源進行匹配,我們可以設計一個合適的排程。

以上是加速 PyTorch 程式碼的一般原則,但不同的後端將對優化什麼做出不同的權衡。例如,Inductor 首先處理它可以融合的任何內容,然後才生成 Triton 核心。

Triton 還提供了加速,因為它在每個流式多處理器中自動進行記憶體合併、記憶體管理和排程,並且設計用於處理分塊計算。

然而,無論您使用哪個後端,最好使用基準測試和觀察方法,嘗試使用 PyTorch profiler,以視覺方式檢查產生的核心,並嘗試了解背後的原因。

為什麼我看不到速度提升?

圖表斷裂 (Graph Breaks)

使用 dynamo 無法達到預期速度提升的主要原因是過多的圖表斷裂。那麼,什麼是圖表斷裂?

給定像這樣的程式碼:

def some_fun(x):
    ...

torch.compile(some_fun)(x)
...

Torchdynamo 會嘗試將 some_fun() 內的所有 torch/tensor 操作編譯到單個 FX 圖中,但它可能無法將所有內容捕捉到一個圖中。

某些圖表斷裂的原因是 TorchDynamo 無法克服的。例如,調用 PyTorch 以外的 C 擴展對於 TorchDynamo 來說是不可見的,並且可能執行任意操作,而 TorchDynamo 無法引入必要的保護措施來確保編譯後的程式可以安全地重複使用。

為了最大化效能,盡可能減少圖表斷裂非常重要。

識別圖表斷裂的原因

要識別程式中的所有圖表斷裂以及與斷裂相關的原因,可以使用 torch._dynamo.explain。此工具會在提供的函式上運行 TorchDynamo,並彙總遇到的圖表斷裂。以下是一個使用範例:

import torch
import torch._dynamo as dynamo
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    print("woo")
    if b.sum() < 0:
        b = b * -1
    return x * b
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation)
"""
Graph Count: 3
Graph Break Count: 2
Op Count: 5
Break Reasons:
  Break Reason 1:
    Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
    User Stack:
      <FrameSummary file foo.py, line 5 in toy_example>
  Break Reason 2:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
Ops per Graph:
  ...
Out Guards:
  ...
"""

要在遇到的第一個圖表斷裂時引發錯誤,您可以使用 fullgraph=True 停用 python 後備方案,如果您使用過基於導出的編譯器,應該對此很熟悉。

def toy_example(a, b):
   ...

torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)

為什麼我更改程式碼後沒有重新編譯?

如果您通過設定 env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py 啟用了動態形狀,那麼您的程式碼將不會在形狀更改時重新編譯。 我們添加了對動態形狀的支援,這樣可以避免在形狀變化小於 2 倍的情況下重新編譯。 這在 CV 中圖像大小變化或 NLP 中可變序列長度等情況下特別有用。 在推理場景中,通常不可能預先知道批次大小,因為您會從不同的客戶端應用程式中獲取可用的批次大小。

一般來說,TorchDynamo 會盡力避免不必要的重新編譯,因此,舉例來說,如果 TorchDynamo 找到 3 個圖,而您的變更只修改了一個圖,那麼只會重新編譯該圖。 因此,避免潛在的緩慢編譯時間的另一個技巧是,先編譯一次模型來預熱它,之後的編譯會快得多。 冷啟動編譯時間仍然是我們可見地追蹤的一個指標。

為什麼我會得到不正確的結果?

如果您設定環境變數 TORCHDYNAMO_REPRO_LEVEL=4,也可以最大限度地減少準確性問題,它以類似的 git bisect 模型運作,完整的重現可能類似於 TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4。我們需要這樣做是因為下游編譯器會產生程式碼,無論是 Triton 程式碼還是 C++ 後端,這些下游編譯器的數值可能存在細微的差異,但會對您的訓練穩定性產生巨大的影響。因此,準確性除錯器對於我們檢測程式碼生成或後端編譯器中的錯誤非常有用。

如果您想確保 torch 和 triton 之間的隨機數生成相同,那麼您可以啟用 torch._inductor.config.fallback_random = True

為什麼我會遇到 OOM(記憶體不足)錯誤?

Dynamo 仍然是一個 alpha 產品,因此存在一些 OOM 的來源,如果您看到 OOM,請嘗試按以下順序停用以下配置,然後在 GitHub 上開啟一個 issue,以便我們可以解決根本問題: 1. 如果您正在使用動態形狀,請嘗試停用它們,我們預設已停用它們: env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py 2. 預設情況下,Inductor 中啟用了帶有 Triton 的 CUDA 圖,但移除它們可能會緩解一些 OOM 問題: torch._inductor.config.triton.cudagraphs = False

torch.func 可以與 torch.compile 一起使用嗎?(用於 gradvmap 轉換)

torch.func 轉換應用於使用 torch.compile 的函式是可以的

import torch

@torch.compile
def f(x):
    return torch.sin(x)

def g(x):
    return torch.grad(f)(x)

x = torch.randn(2, 3)
g(x)

在由 torch.compile 處理的函式內部調用 torch.func 轉換

使用 torch.compile 編譯 torch.func.grad

import torch

def wrapper_fn(x):
    return torch.func.grad(lambda x: x.sin().sum())(x)

x = torch.randn(3, 3, 3)
grad_x = torch.compile(wrapper_fn)(x)

使用 torch.compile 編譯 torch.vmap

import torch

def my_fn(x):
    return torch.vmap(lambda x: x.sum(1))(x)

x = torch.randn(3, 3, 3)
output = torch.compile(my_fn)(x)

編譯支援函式以外的函式(緊急出口)

對於其他轉換,作為一種應急方案,可以使用 torch._dynamo.allow_in_graph

allow_in_graph 是一個緊急出口。 如果您的程式碼無法與 torch.compile 一起使用(後者會檢查 Python 位元組碼),但您認為它將透過符號追蹤方法(如 jax.jit)運作,那麼請使用 allow_in_graph

透過使用 allow_in_graph 註釋函式,您必須確保您的程式碼滿足以下要求

  • 函式中的所有輸出僅取決於輸入,而不取決於任何捕獲的張量。

  • 你的函式是純粹的 (functional)。也就是說,它不會修改任何狀態。這個限制可以放寬;實際上,我們支援從外部看起來是純粹的函式:它們可能具有 in-place 的 PyTorch 操作,但不得修改全域狀態或函式的輸入。

  • 你的函式不會引發 data-dependent 的錯誤。

import torch

@torch.compile
def f(x):
    return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

x = torch.randn(2, 3)
f(x)

一個常見的陷阱是使用 allow_in_graph 來註解一個調用 nn.Module 的函式。這是因為輸出現在取決於 nn.Module 的參數。要使其正常運作,請使用 torch.func.functional_call 來提取模組狀態。

NumPy 能夠與 torch.compile 搭配使用嗎?

從 2.1 開始,torch.compile 能夠理解在 NumPy 陣列上運作的原生 NumPy 程式,以及透過 x.numpy()torch.from_numpy 和相關函式在 PyTorch 與 NumPy 之間轉換的混合 PyTorch-NumPy 程式。

torch.compile 支援哪些 NumPy 功能?

torch.compile 中的 NumPy 遵循 NumPy 2.0 的預先發布版本。

一般來說,torch.compile 能夠追蹤大多數 NumPy 建構,當它無法追蹤時,它會退回到 eager 模式,並讓 NumPy 執行該段程式碼。即便如此,torch.compile 的語義與 NumPy 的語義略有不同的地方。

  • NumPy 纯量:我們將它們建模為 0-D 陣列。也就是說,在 torch.compile 下,np.float32(3) 會傳回一個 0-D 陣列。為了避免圖表斷裂,最好使用此 0-D 陣列。如果這導致你的程式碼崩潰,你可以透過將 NumPy 純量轉換為相關的 Python 純量類型 bool/int/float 來解決此問題。

  • 負步長:np.flip 和使用負步長進行切片會傳回副本。

  • 類型提升:NumPy 的類型提升將在 NumPy 2.0 中發生變化。新的規則在 NEP 50 中描述。torch.compile 實作的是 NEP 50,而不是目前即將被棄用的規則。

  • {tril,triu}_indices_from/{tril,triu}_indices 傳回陣列,而不是陣列的元組。

還有其他一些我們不支持追蹤的功能,我們會優雅地回退到 NumPy 來執行它們

  • 非數值型別,如日期時間、字串、字符、void、結構化型別和記錄陣列。

  • Long 型別 np.float128/np.complex256 和一些無符號型別 np.uint16/np.uint32/np.uint64

  • ndarray 子類別。

  • 遮罩陣列。

  • 深奧的 ufunc 機制,如 axes=[(n,k),(k,m)->(n,m)] 和 ufunc 方法 (例如,np.add.reduce)。

  • 排序 / 排序 complex64/complex128 陣列。

  • NumPy np.poly1dnp.polynomial

  • 具有 2 個或更多傳回值的函式中的位置參數 out1, out2 ( out=tuple 可以運作)。

  • __array_function____array_interface____array_wrap__

  • ndarray.ctypes 屬性。

我可以使用 torch.compile 編譯 NumPy 程式碼嗎?

當然可以!torch.compile 可以原生理解 NumPy 程式碼,並將其視為 PyTorch 程式碼。為此,只需使用 torch.compile 裝飾器包裝 NumPy 程式碼即可。

import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)

使用環境變數 TORCH_LOGS=output_code 執行此範例,我們可以發現 torch.compile 能夠將乘法和總和融合到一個 C++ 核心中。它也能夠使用 OpenMP 平行執行它們 (原生 NumPy 是單執行緒的)。這可以輕易地使你的 NumPy 程式碼快 n 倍,其中 n 是你處理器中的核心數!

以這種方式追蹤 NumPy 程式碼也支援編譯程式碼中的圖表斷裂。

我可以在 CUDA 上執行 NumPy 程式碼,並透過 torch.compile 計算梯度嗎?

是的,可以!為此,你只需在 torch.device("cuda") 上下文中執行你的程式碼即可。考慮以下範例

import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)

在這個例子中,numpy_fn 將在 CUDA 中執行。為了使這成為可能,torch.compile 會自動將 XY 從 CPU 移到 CUDA,然後將結果 Z 從 CUDA 移到 CPU。如果我們在同一個程式運行中多次執行這個函式,我們可能希望避免所有這些相當昂貴的記憶體複製。為此,我們只需要調整我們的 numpy_fn,使其接受 cuda 張量並傳回張量。我們可以透過使用 torch.compiler.wrap_numpy 來實現這一點

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"

在這裡,我們明確地在 CUDA 記憶體中建立張量,並將它們傳遞給函式,函式在 CUDA 裝置上執行所有計算。wrap_numpy 負責將任何 torch.Tensor 輸入標記為在 torch.compile 層具有 np.ndarray 語義的輸入。在編譯器中標記張量是一個非常便宜的操作,因此在執行期間不會發生資料複製或資料移動。

使用這個裝飾器,我們也可以透過 NumPy 程式碼進行微分!

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)))

X = torch.randn(1024, 64, device="cuda", requires_grad=True)
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
Z.backward()
# X.grad now holds the gradient of the computation
print(X.grad)

我們一直使用 fullgraph=True,因為在此情況下,圖中斷 (graph break) 會產生問題。當發生圖中斷時,我們需要將 NumPy 陣列具體化 (materialize)。由於 NumPy 陣列沒有 devicerequires_grad 的概念,因此這些資訊會在圖中斷期間遺失。

我們無法透過圖中斷傳播梯度,因為圖中斷程式碼可能會執行任意程式碼,而這些程式碼不知道如何進行微分。另一方面,在 CUDA 執行的情況下,我們可以像第一個範例一樣,透過使用 torch.device("cuda") 上下文管理器來解決此問題。

@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    prod = X[:, :, None] * Y[:, None, :]
    print("oops, a graph break!")
    return np.sum(prod, axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")

with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"

在圖中斷期間,中間張量仍然需要移動到 CPU,但是當在圖中斷後恢復追蹤時,其餘的圖仍然在 CUDA 上追蹤。 考慮到 CUDA <> CPU 和 CPU <> CUDA 之間的移動,圖中斷在 NumPy 環境中相當耗費資源,應該避免使用。但至少它們允許追蹤複雜的程式碼片段。

如何在 torch.compile 下除錯 NumPy 程式碼?

鑑於現代編譯器的複雜性以及它們引發的令人望而生畏的錯誤,除錯 JIT 編譯程式碼具有挑戰性。torch.compile 除錯文件 包含一些關於如何處理此任務的提示和技巧。

如果以上方法不足以找出問題的根源,我們仍然可以使用其他一些特定於 NumPy 的工具。我們可以透過停用對 NumPy 函數的追蹤來判斷錯誤是否完全在 PyTorch 程式碼中。

from torch._dynamo import config
config.trace_numpy = False

如果錯誤存在於追蹤的 NumPy 程式碼中,我們可以透過匯入 import torch._numpy as np,以 eager 模式(沒有 torch.compile)執行 NumPy 程式碼,並將 PyTorch 作為後端。這應該僅用於**除錯目的**,絕不能取代 PyTorch API,因為它的**效能遠遠較差**,並且作為私有 API,**可能會在沒有通知的情況下變更**。 無論如何,torch._numpy 是使用 PyTorch 術語實現的 NumPy 的 Python 版本,並且 torch.compile 在內部使用它將 NumPy 程式碼轉換為 Pytorch 程式碼。它相當容易閱讀和修改,因此如果您發現任何錯誤,請隨時提交 PR 以修復它,或直接開啟一個 issue。

如果程式在匯入 torch._numpy as np 時可以運作,則錯誤很可能在 TorchDynamo 中。 在這種情況下,請隨時開啟一個 issue,並提供一個最小重現範例

torch.compile 一些 NumPy 程式碼,但沒有看到任何加速。

最好的起點是包含除錯此類 torch.compile 問題的一般建議的教學課程

某些圖中斷的發生可能是由於使用了不支援的功能。請參閱torch.compile 支援哪些 NumPy 功能?。 更廣泛地說,請記住一些廣泛使用的 NumPy 功能無法與編譯器很好地協同工作。 例如,原地修改使得在編譯器內進行推理變得困難,並且通常會產生比其異地對應物更差的效能。因此,最好避免它們。 對於使用 out= 參數也是如此。 而是首選異地操作,並讓 torch.compile 優化記憶體使用。 對於資料相關的操作(例如透過布林遮罩進行的遮罩索引),或資料相關的控制流程(例如 ifwhile 建構)也是如此。

哪個 API 用於細粒度追蹤?

在某些情況下,您可能需要從 torch.compile 編譯中排除程式碼的一小部分。本節提供了一些答案,您可以在 TorchDynamo 用於細粒度追蹤的 API 中找到更多資訊。

如何對函數進行圖中斷?

僅對函數進行圖中斷不足以充分表達您希望 PyTorch 執行的操作。 您需要更具體地說明您的使用案例。 您可能需要考慮一些最常見的使用案例

  • 如果您想要停用此函數框架和遞迴調用框架上的編譯,請使用 torch._dynamo.disable

  • 如果您希望特定運算符(例如 fbgemm)使用 eager 模式,請使用 torch._dynamo.disallow_in_graph

一些不常見的使用案例包括

  • 如果您想要停用函數框架上的 TorchDynamo,但重新啟用遞迴調用框架上的 TorchDynamo,請使用 torch._dynamo.disable(recursive=False)

  • 如果您想要防止函數框架的內聯,請在您想要防止內聯的函數開頭使用 torch._dynamo.graph_break

torch._dynamo.disabletorch._dynamo.disallow_in_graph 之間有什麼區別?

Disallow-in-graph 在運算符層級運作,或更具體地說,在您在 TorchDynamo 提取的圖中看到的運算符層級運作。

Disable 在函數框架層級運作,並決定 TorchDynamo 是否應查看函數框架。

torch._dynamo.disabletorch._dynamo_skip 之間有什麼區別?

注意

torch._dynamo_skip 已棄用。

您很可能需要 torch._dynamo.disable。 但是在不太可能的情況下,您可能需要更精細的控制。 假設您只想停用 a_fn 函數上的追蹤,但想要在 aa_fnab_fn 中繼續追蹤。 下圖展示了此使用案例

diagram of torch.compile + disable(a_fn, recursive=False)

在這種情況下,您可以使用 torch._dynamo.disable(recursive=False)。 在先前的版本中,此功能由 torch._dynamo.skip 提供。 現在,torch._dynamo.disable 內的 recursive 標記支援此功能。

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學課程

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源