捷徑

functorch.compile.aot_function

functorch.compile.aot_function(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, num_params_buffers=0, hasher_type=None, static_argnums=None, keep_inference_input_mutations=False)[來源]

使用 Torch 的調度機制追蹤 fn 的正向和反向圖,然後通過 fw_compilerbw_compiler 編譯生成的正向和反向圖。

aot_function() 會預先追蹤正向和反向圖,並生成一個聯合的正向和反向圖。然後使用 partition_fn 將正向和反向圖分開。分割函數可用於執行諸如重新計算之類的優化。可以設置 decompositions 字典,將運算子分解為後端編譯器支援的核心或更簡單運算子的序列。

aot_function() 使用基於輸入張量屬性的編譯快取來檢測何時需要重新編譯。

警告

此 API 仍處於實驗階段,可能會有所變動。

參數
  • fn (Callable) – 接受一個或多個參數的 Python 函數。必須返回一個或多個張量。

  • fw_compiler (Callable) – 接受包含 Aten 運算和輸入參數的 Fx 圖的 Python 函數,並返回一個 Callable,其語義等效於輸入的 Fx 圖。

  • bw_compiler (Optional[Callable]) – 接受包含 Aten 運算和輸入參數的 Fx 圖的 Python 函數,並返回一個 Callable,其語義等效於輸入的 Fx 圖。預設值:None(當為 None 時,預設為 fw_compiler

  • partition_fn (Callable) – 接受聯合的正向和反向圖,並將其分割為單獨的正向和反向圖的 Python 函數。

  • decompositions (Dict) – 定義將較大的 Aten 運算分解為更簡單或核心 Aten 運算的字典。

返回值

返回一個 Callable,保留原始 fn 的 Eager 行為,但正向和反向圖已通過 fw_compilebw_compile 編譯。

aot_function() 的一個簡單用法示例如下。此示例將打印函數 fn 的正向和反向圖。

>>> fn = lambda x : x.sin().cos()
>>> def print_compile_fn(fx_module, args):
>>>     print(fx_module)
>>>     return fx_module
>>> aot_fn = aot_function(fn, print_compile_fn)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x)

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得適用於初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源