functorch.compile.memory_efficient_fusion¶
-
functorch.compile.
memory_efficient_fusion
(fn, static_argnums=None, **kwargs)[source]¶ 封裝
aot_function()
和aot_module()
的函式,以進行記憶體效益佳的融合操作。此函式使用min_cut_rematerialization_partition()
分割器來執行有效率的重新計算。此函式使用 NVFuser 來編譯產生的前進和後退圖形。警告
此 API 為實驗性,可能會有所變更。
- 參數
fn (聯合[可呼叫, nn.Module]) – 一個 Python 函式或
nn.Module
,接受一個或多個引數。必須傳回一個或多個 Tensor。static_argnums (Optional[Tuple[Int]]) – 選擇性的整數元組,用於標記函式的引數為靜態。
**kwargs – 任何其他對設定的覆寫動作
- 傳回值
傳回一個
Callable
或nn.Module
,用於保留原始fn
的急切行為,但其前進和後退圖形已經過重新計算最佳化,而且這些圖形已使用 nvfuser 編譯。