快捷方式

動態形狀

程式碼:symbolic_shapes.py

另請參閱:動態形狀手冊

動機

深度學習編譯器通常僅適用於靜態形狀,也就是說,它們產生的編譯程式僅適用於輸入形狀的單個特定配置,如果任何輸入形狀發生變化,則必須重新編譯。這種假設對於當今大多數常見的深度學習模型來說效果很好,但在某些情況下是不夠的

  • 某些維度(例如批次大小或序列長度)可能會有所不同。例如,執行自適應批次的推論服務將根據其批次視窗內收到的請求數量,以不同的批次大小執行推論請求。我們可能還想僅將可變大小的序列填充到批次中的最大序列長度,這可能因批次而異。

  • 有些模型會展現出資料相關的輸出形狀,也就是說,它們輸出和中間變數的大小可能會取決於實際的輸入資料,而輸入資料在不同的執行中可能會有所不同。例如,偵測模型可能會先產生數量不定的潛在邊界框,然後再執行更耗費資源的影像辨識模型,以辨識主體是否在邊界框內。邊界框的數量取決於資料。

  • 資料相關形狀一個特別重要的例子發生在處理稀疏表示法時,例如稀疏張量、鋸齒張量和圖神經網路。在所有這些情況下,要處理的資料量取決於問題的稀疏結構,而稀疏結構通常會以資料相關的方式變化。

在支援動態形狀時,我們選擇不支援動態秩程式,例如,輸入張量的維度會變化的程式,因為這種模式在真實世界的深度學習程式中很少發生,而且可以避免對形狀的符號列表進行歸納推理的需求。

精簡後的公開 API

PyTorch 2.1 中的預設動態行為是

  • PT2 預設會假設所有內容都是靜態的

  • 如果我們因為大小改變而重新編譯,我們會嘗試將該大小重新編譯為動態的(已經改變的大小很可能在未來會改變)。這種廣義化可能會失敗(例如,因為使用者程式碼在有問題的大小上執行條件分支,或 PT2 中缺少對動態形狀的支援)。如果您想了解為什麼 PT2 過度專業化了一些程式碼,請使用 TORCH_LOGS=dynamic 執行,並尋找顯示何時以及為什麼新增保護(guards)的 "eval" 條目。

  • 如果您事先知道某件事會是動態的,您可以使用 torch._dynamo.mark_dynamic(tensor, dim) 跳過第一次重新編譯。如果您事先知道此維度的 minmax 值,您可以指定 torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)

  • 如果您說 torch.compile(dynamic=False),我們將關閉重新編譯時的自動動態形狀,並且始終針對每個不同的大小重新編譯。相反地,如果您說 torch.compile(dynamic=True),我們將嘗試使所有內容盡可能地動態化。這主要適用於小型運算子;如果您嘗試在大型模型上使用它,(1) 可能會導致 PT2 崩潰,並且 (2) 運行速度會很慢,沒有任何好處。

保護模型

在考慮如何將動態形狀支援添加到 TorchDynamo 和 TorchInductor 時,我們做出了一個重要的設計決策:為了重複使用 Python/C++ 中針對 PyTorch API 編寫的分解和其他現有程式碼,我們必須能夠追蹤動態形狀。與可能捕獲條件分支的兩個分支的完全符號系統不同,我們總是選擇一個分支並專門化我們的追蹤,假設我們僅在未來對該分支做出相同的選擇時才使用此追蹤。為此,我們為每個符號大小保留一個「提示」,說明其在編譯時的具體值(由於 TorchDynamo 是一個即時編譯器,因此它始終知道實際的輸入大小)。當我們對張量執行條件判斷時,我們只需查閱提示即可找出要採用的分支。

這大大簡化了我們產生的符號形狀公式,但也意味著我們有一個更複雜的系統來管理保護。例如,考慮以下程式:

def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)

我們將使用 TorchInductor 編譯的最終 IR 要么是 torch.cat([x, y]).add(2),要么是 torch.cat([x, y]).mul(2)(條件已扁平化),但是要確定我們在哪個分支中,我們需要知道中間變數 z 的大小。因為 TorchDynamo 必須預先知道已編譯的追蹤是否有效(我們不支援像某些 JIT 編譯器那樣的保釋 (bailout)),我們必須能夠將 z.size(0) 作為輸入 x.size(0) + y.size(0) 的表達式來簡化。這是通過為 PyTorch 中的所有運算子編寫元函數來完成的,這些元函數可以將大小資訊傳播到張量的輸出,而無需實際在節點上執行計算。

整體架構

符號形狀工作流程

  1. 當我們開始在 Dynamo 中編譯幀時,我們會分配一個 ShapeEnv(附加到 FakeTensorMode),它會追蹤符號形狀狀態。

  2. 我們在輸入時為張量分配符號大小(靜態或動態是一個策略決策,帶有一些旋鈕)。

  3. 我們通過運算子傳播符號大小,同時維護 (1) FX IR,以便我們可以忠實地匯出符號計算,以及 (2) 表示大小變數的 Sympy 表達式,以便我們可以對它們進行推理。

  4. 當我們在 Dynamo 追蹤或 Inductor 優化中對符號大小進行條件判斷時,我們會根據條件新增保護。這些可以從 Python 和 C++ 中推導出來。

  5. 這些保護可以進一步簡化符號變數。例如,如果您斷言 s0 == 4,我們現在可以用 4 替換所有出現的 s0

  6. 當我們完成追蹤和優化後,我們會將所有這些保護與編譯後的程式碼一起安裝;只有當所有保護都評估為 true 時,編譯後的程式碼才能重複使用。

重要檔案

  • C++ SymInt API: c10/core/SymInt.h, SymFloat.h, SymBool.h

  • Python SymInt API: torch/__init__.py (尋找 SymInt/SymFloat/SymBool)

  • C++ 管道: c10/core/SymNodeImpl.h, torch/csrc/utils/python_symnode.h, torch/csrc/jit/python/init.cpp

  • Python 基礎架構: torch/fx/experimental/symbolic_shapes.py

  • 其他重要檔案: torch/_subclasses/fake_tensor.py, torch/_meta_registrations.py, decomps, PrimTorch refs

精簡後的內部 API

了解 Python 類別層次結構

  • SymInt/SymFloat/SymBool:這些是用戶可見的類別,它們模擬其 int/float/bool 對應項。如果您新增兩個 SymInt,我們會為您提供一個新的 SymInt,它以符號方式追蹤已發生的整數加法。

  • SymNode:這是內部結構(可透過例如 symint.node 存取),用於保存實際的符號追蹤資訊。SymNode 是類型抹除 (type erased) 的;這使得表示混合類型操作更方便。請注意,技術上您不必從 SymInt 呼叫 Python SymNode;例如,XLA 的 C++ SymNodeImpl 將取代 SymNode。

  • ShapeEnv:每個編譯上下文的狀態,用於追蹤到目前為止我們累積的所有自由符號和保護 (guard)。每個 SymNode 都記錄其 ShapeEnv(但反之則不然;SymNode 只有在參與保護時才被使用)。

C++ 非常相似

  • c10::SymInt/SymFloat/SymBool:使用者可見的類別,用於模擬 int/float/bool。

  • c10::SymNode/SymNodeImpl:類似於 SymNode

  • C++ 中沒有 ShapeEnv;為了便於除錯,整個符號推理裝置都在 Python 中。

當您編寫可使用 make_fx 追蹤的程式碼時,它必須能夠處理流經它的 SymInt/SymFloat/SymBool。 動態形狀手冊 提供了一些關於如何做到這一點的指導。

DimDynamic 策略

符號推理

  • 值範圍

  • Sympy 使用注意事項

  • 約束

  • DimDynamic/Constraint

未支持 (Unbacked) 的 SymInts

為了解析控制流程,我們會檢查符號整數的提示 (hint),也就是實際值,以確定要走哪個分支。但是,在某些情況下,我們可能沒有提示:當大小變數從資料相關的操作(如 .nonzero().item())中產生時,就會出現所謂的未支持符號整數。對這些符號整數執行控制流程是非法的,因此我們必須在這些操作上進行圖形中斷 (graph break)。

如果天真地實作,這限制太多了:如果您嘗試對未支持的符號整數做任何事情,大多數 PyTorch 程式會立即失敗。以下是使它真正運作的最重要的增強功能

  • 在張量建立時,PyTorch 會預先計算有關張量的大量資料;例如,如果您使用 empty_strided 建立張量,我們將會急切地對 stride 進行排序,並確定張量是否是不重疊且密集的。排序會產生大量的保護。然而,更常見的是使用更高等級的 API(如 empty)直接產生張量,這保證會產生一個不重疊且密集的張量。我們修改了 PyTorch 以避免不必要地重新計算這些屬性。

  • 即使需要進行重要的計算,有時也根本不會查詢屬性。使這些預先計算的屬性變得延遲 (lazy) 讓我們可以避免對未支持的符號整數進行保護,除非實際需要它。

  • 整數張量中的資料通常未知是否為非負數。但是,我們提供了一個 API constrain_range,使用者可以透過它指定大小由已知的限制上下界定。

在 PT2 的未來版本(超越 PT2.1)中,我們將擴展我們的推理系統,根據使用情況推斷未支持的符號整數是類似於大小的。例如,如果您將 .item() 呼叫的結果傳遞給像 torch.empty 這樣的工廠函數,我們會自動推斷結果是一個大小(因為如果不是,它就會失敗。)此假設將在執行時進行驗證,如果未滿足,則會引發錯誤。

文件

存取 PyTorch 的綜合開發者文件

查看文件

教學

獲取針對初學者和高級開發者的深入教學

查看教學

資源

尋找開發資源並獲得您的問題解答

查看資源