快捷方式

PyTorch 2.0 NNModule 支援

作者: Will Constable

torch.compile 對於 torch.nn.Module 物件有特殊的處理方式,追蹤它們的方式與追蹤任意 Python 類別不同,目的是透過對結構進行假設來產生更快的程式碼。

此文件描述了由於這種特殊化而出現的一些權衡或邊緣情況。

NNModule 鉤子支援

先前,torch.compile 不支援 nn.Modules 上的鉤子,如果註冊了鉤子,它們將在編譯後的程式中被忽略。事實上,許多使用者根本不使用 nn.Module 鉤子,或者只將它們用於偵錯工作流程,但將 nn.Module 鉤子與 torch.compile 組合在一起是有有效的用例的。

透過 nn.Module.__call__ 實作編排的鉤子包括 _forward_pre_hooksforward_hooks_backward_pre_hooks_backward_hooks,將被稱為「呼叫鉤子」。這些鉤子受到 torch.compile 的部分支援,但有如下所述的限制。

另一類鉤子包括 _state_dict_hooks 及其 preload_ 變體,torch.compile 仍然不支援這些鉤子。

nn.Module.__call__ 鉤子的使用和限制

預設情況下,torch.compile 將追蹤 nn.Module.__call__ 的內容,這意味著它會遇到並執行 forward/pre-forward hooks。如果您在呼叫 torch.compile 之前安裝了 hooks,然後之後沒有移除或更改這些 hooks,那麼您的使用案例應該預設受到支援。

一般來說,Backward/Pre-backward hooks 也受到支援,但有一些類似的注意事項:目前,當存取 backward_hooks 字典時,dynamo 中會發生 graph-breaks,這可能可以透過一些工作來避免。 Graph-breaks 也會影響觸發 backward hooks 的時機,因為 graph-segments 會作為 autograd-functions 執行,這些函數會同時產生它們所有的 grads。假設 dynamo 有可能在不因 backward-hooks 的存在而產生 graph-break 的情況下運作,我們仍然會預期一系列模組的 backward hooks 在整個編譯圖的 backward 執行完畢後一起觸發。

‘允許的模組’上的 hooks torch.compile 將常見的模組(例如 torch.conv)以及難以追蹤的模組視為特殊情況,允許它們在 dynamo 圖中以不透明的方式呼叫,而不是由 dynamo 追蹤進去。 對於這些模組,hooks 目前會觸發 graph-break,以便受影響的模組在 dynamo 之外執行。 根據模型,這可能會導致顯著的效能下降,並且需要額外的工作來改進這種支援。

skip_nnmodule_hook_guards 預設情況下,torch._dynamo.config.skip_nnmodule_hook_guards 設定為 True,這表示不會在每個 nn.Module hook 字典上安裝 guards,從而通過減少 guard 執行時間來提高執行時間,但代價是如果編譯後任何 hook 字典發生更改,則不會注意到。

如果您希望在編譯後能夠移除或修改 hooks,並讓 torch.compile 做出適當的反應(透過重新編譯),那麼您需要設定 skip_nnmodule_hook_guards=False 並預期因新增的 guards 而產生的執行時間損失。

TODO:確認 backward/pre_backward hooks 是否正在運作,並據此記錄。

state_dict Hooks

State dict hooks 尚未在 torch.compile 中得到支援。

TODO:如果因 hooks 產生 graph-breaking,則發出 warn_once。 如果存在 hooks,則發出 warn_once 並指向此文檔。

文件

存取 PyTorch 的完整開發人員文檔

檢視文檔

教學

取得適用於初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您問題的解答

檢視資源