functorch.compile.aot_function¶
-
functorch.compile.
aot_function
(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, num_params_buffers=0, hasher_type=None, static_argnums=None, keep_inference_input_mutations=False)[來源]¶ 使用 Torch 的調度機制追蹤
fn
的正向和反向圖,然後通過fw_compiler
和bw_compiler
編譯生成的正向和反向圖。aot_function()
會預先追蹤正向和反向圖,並生成一個聯合的正向和反向圖。然後使用partition_fn
將正向和反向圖分開。分割函數可用於執行諸如重新計算之類的優化。可以設置 decompositions 字典,將運算子分解為後端編譯器支援的核心或更簡單運算子的序列。aot_function()
使用基於輸入張量屬性的編譯快取來檢測何時需要重新編譯。警告
此 API 仍處於實驗階段,可能會有所變動。
- 參數
fn (Callable) – 接受一個或多個參數的 Python 函數。必須返回一個或多個張量。
fw_compiler (Callable) – 接受包含 Aten 運算和輸入參數的 Fx 圖的 Python 函數,並返回一個 Callable,其語義等效於輸入的 Fx 圖。
bw_compiler (Optional[Callable]) – 接受包含 Aten 運算和輸入參數的 Fx 圖的 Python 函數,並返回一個 Callable,其語義等效於輸入的 Fx 圖。預設值:None(當為 None 時,預設為
fw_compiler
)partition_fn (Callable) – 接受聯合的正向和反向圖,並將其分割為單獨的正向和反向圖的 Python 函數。
decompositions (Dict) – 定義將較大的 Aten 運算分解為更簡單或核心 Aten 運算的字典。
- 返回值
返回一個
Callable
,保留原始fn
的 Eager 行為,但正向和反向圖已通過fw_compile
和bw_compile
編譯。
aot_function()
的一個簡單用法示例如下。此示例將打印函數fn
的正向和反向圖。>>> fn = lambda x : x.sin().cos() >>> def print_compile_fn(fx_module, args): >>> print(fx_module) >>> return fx_module >>> aot_fn = aot_function(fn, print_compile_fn) >>> x = torch.randn(4, 5, requires_grad=True) >>> aot_fn(x)