• 文件 >
  • Pytorch/XLA 中重新編譯的來源
捷徑

Pytorch/XLA 中重新編譯的來源

讓我們先從一些事實/限制開始:

  1. XLA 中的圖形編譯相當耗費資源。

  2. XLA 僅處理靜態形狀。換句話說,即使是相同的 IR 圖形,當輸入形狀改變時,XLA 也會重新編譯。

  3. 當重新編譯發生時,會嚴重影響 torch_xla 的效能,且對於一般的 Python 使用者來說,難以理解和除錯。

通常當重新編譯發生時,我們會說我們只需要動態形狀支援,然後確信當未來支援動態形狀時,所有的重新編譯都會神奇地消失。但事實並非如此,XLA 現在已經有相當不錯的有限動態形狀覆蓋率,但我們仍然看到重新編譯,而且這是預期的。

本文旨在詳細解釋一些常見的重新編譯來源,以及我們需要做些什麼來擺脫它們。它將主要著重於向沒有任何背景知識的初學者解釋問題。為了方便理解,這裡提出的「解決方案」可能基於不切實際的假設。

#1. 來自輸入資料集。

是的,輸入資料集包含不同形狀的範例非常常見,例如長度不同的句子或大小不同的圖片。若不進行正規化,每次新的輸入形狀都會導致重新編譯。

Tensorflow 圖形模式使用者更習慣使用 padding/bucketization (tf.pad) 將輸入形狀正規化為一個或幾個 bucket。但這對於 PyTorch eager 前端使用者(也是 lazy tensor 前端試圖鎖定的使用者)來說有點反模式,因為不同的輸入形狀對於 eager CPU/CUDA 後端來說根本不重要。

建議的解決方案: 好的,現在假設我們可以透過教導使用者進行 padding/bucketization 來解決這個問題(實際上很難 :P)。接下來呢?

#2. 來自運算子輸出

有些運算子在語義上是資料相依的,並產生動態形狀輸出:例如,torch.nonzero 會傳回輸入張量中非零元素的索引。因此,即使您輸入到此運算子的張量始終具有相同的形狀,它也可能會產生不同的形狀輸出並導致重新編譯。

2.1 當您將具有動態形狀的張量作為張量使用,而不查詢其實際維度時,有限動態形狀可以解決此問題。

建議的解決方案: 假設現在 XLA 支援所有運算子的有限動態形狀,這樣就夠了嗎?

  • 有限動態形狀表示我們可以將張量填充到理論最大值,以犧牲更多記憶體使用量來換取更少的重新編譯/更快的速度。

嗯,某種程度上是。讓我們看看以下範例

a = torch.tensor([1, 2, 0, 1, 3], device='xla')
b = torch.nonzero(a)
c = b * 2
d = c + 1
print(torch_xla._XLAC._get_xla_tensors_text([d]))

在上面的範例中,圖形中 b 下方的每個節點(即 c, d 和所有依賴它們的節點)都將具有動態形狀,很明顯 b 在維度 0 中具有動態形狀,如下所示

%9 = (s64[<=5,1]{1,0}, s64[]) aten::nonzero(%8), num_outputs=2 # b
%10 = s64[5,1]{1,0} aten::mul(%9.0, %3) # c
%11 = s64[5,1]{1,0} aten::add(%10, %2), ROOT=0 # d

雖然在圖形中沒有直接顯示,但 c & d 實際上也具有動態形狀(換句話說,[5, 1] 只是填充形狀,並且被遮罩)。

print(torch_xla._XLAC._get_xla_tensor_dimension_size(d, 0)) # prints 4 instead of 5

您可以看到,在這種情況下,只要輸入張量 a 的形狀為 [5],我們就只會編譯圖形一次。有限動態形狀支援有所幫助!

2.2 如果在具有動態形狀的張量上查詢實際維度會怎樣?

這實際上非常常用,因為並非所有 PyTorch 計算都以張量的形式完成。

例如,PyTorch 中的 tensor.size() 會傳回整數的 tuple,而不是 dtype=int 的張量。當 tensor 是動態形狀張量時,此運算基本上會強制 XLA 切割圖形並進行評估,以便我們可以傳回正確的純量(否則它只會傳回錯誤的填充形狀)。

更糟糕的是,許多 PyTorch 也接受純量輸入。在您執行 s = tensor.size(0) 並在其他運算子中使用 s 之後,它也會變成動態來源。在這種情況下,我們可能知道如何填充它及其上限,但我們無法做到,因為它甚至不是張量!

a = torch.tensor([1, 2, 0, 1, 3], device='xla')
b = torch.nonzero(a)
s = a.size(0) # evaluation happens! nit: we use size() for simplicity, the actual API is _get_xla_tensor_dimension_size.
c = torch.rand(s, device='xla') # c can be of any shape between [0, 5] which causes more recompilations!
d = c + 1

所以這個問題實際上很難在沒有 PyTorch 前端協助的情況下解決。我們需要什麼?

簡而言之,我們需要一個張量世界!

例如,

  • tensor.size() 應該傳回一個張量,以便它可以是一個具有動態形狀的張量,並保留在圖形中而無需提前評估。

  • 張量存取器,例如對於 2D 張量,tensor[0][0] 現在傳回一個值,但這也需要傳回一個張量。

  • 隱含地,這意味著目前所有以 int/float/double 作為輸入的運算子也需要張量重載。這是一個很大的要求,因為它很容易使我們的運算子集合爆炸。

    • 如果我們可以讓純量到張量的轉換非常便宜,這樣我們就可以只關心張量重載,那就更容易了。

    • 實際上,並非所有運算都從先前的計算中取得純量,因此我們一直透過臨時請求添加張量變體。

    • 我認為這也是基於追蹤方法的一個常見要求。

好的,現在我們假設 PyTorch 中的每個運算都有我們需要的張量版本,這樣就完成了嗎?

#3. 來自控制流程

不!我們實際上只解決了沒有資料相依控制流程的問題…

請參閱以下範例

if x[0][0] == 3:
  bla
else:
  blabla

即使 x[0][0] 是張量,我們也需要執行/實質化它的值,python 解釋器才能繼續執行。而多個控制流程中不同的分支選擇組合意味著我們也需要編譯大量的圖形!

目前我們還沒有辦法解決這個問題。要解決它,我們需要將控制流程從 python 降低到圖形!在沒有過多思考實作的情況下,我們可以透過兩種方式做到這一點

  • 要求使用者明確使用控制流程運算子來取代 python if/else/while/for。目前在 torch_xla 中以自訂 API 支援,但尚未在使用者程式碼中廣泛採用。(python 使用者習慣了 if/else/for,除非有巨大的效能提升,否則很難讓他們切換到更醜陋的 API)。

  • 解析 python 原始碼。程式碼以自動取得控制流程語句。這就像 Torchscript,並以某種方式將 torchscripted 圖形正確地合併到懶惰追蹤圖形中(包括形狀資訊等)。我確實還沒有想清楚如何實作這個步驟 :P

但是以上兩種解決方案都需要相當大的努力,無論是在使用者端還是框架端。這就是為什麼我們目前僅將早期評估和多次編譯作為短期解決方案,因為我們目前的頻寬有限。

好的,現在我們假設控制流程也自動降低到圖形中,這樣我們就成功了嗎?

是的!現在您的整個計算都以張量運算的圖形表示,包括控制流程,以便編譯器現在可以取用並執行它們的智慧技巧!但老實說,在這種情況下,您的程式已經不再那麼 PyTorch-y 了。

結論:

實際上,重新編譯有多個來源,而有限動態形狀支援無法解決所有問題。本文中提出的解決方案有時肯定是不切實際的,並且可能存在更好的方法來正確解決每個來源,而我完全沒有意識到。但我希望隨著我們在本文中不斷努力邁向理想的 lazy tensor 堆疊,現在您更容易理解我們前方剩餘的阻礙是什麼。

附錄:

  1. NNC 使用符號形狀,這有幫助嗎?

是的,但只是部分。透過使用符號形狀,您的編譯最佳化不再需要具體的形狀值。換句話說,您產生的核心比 XLA 的靜態形狀核心更通用。

這到底有助於解決哪個問題?

它有助於解決像 #1 和 #2.1 這樣的案例。

shape [3, 5] -> add -> transpose -> ... -> mul
shape [6, 2] -> add -> transpose -> ... -> mul

# with symbolic shape
shape [x, y] -> add -> transpose -> ... -> mul

使用符號形狀,您產生的核心不會像 XLA 使用靜態形狀那樣重新編譯。

XLA 以另一種方式解決這個問題,透過使用 padding/bucketization(對於 #1)和有限動態形狀(對於 #2.1)。

Brian Hirsh(@bdhirsh) 在評論中提出了一些非常好的問題,移到這裡讓它們更顯眼

  1. 是否值得在產生資料相依輸出形狀的運算子的 XLA 核心中加入 TORCH_WARN?

是的,torch_warn 有助於告訴使用者「嘿,你的程式不會跑得飛快」。但對於這些資料相依的運算子,除非使用者更改其模型中的邏輯,否則沒有簡單的重寫方法。(另一個例子是 torch.unique())

  1. 像 nonzero 這樣的運算子如何影響我們 devirtualize sizes() 的能力?如果我們想要 devirtualize sizes(),我們需要能夠 eager 地計算每個運算子的 size - 這是否意味著每次我們遇到像 nonzero 這樣的運算子時,我們都被迫評估圖形?與現在相比,聽起來我們在使用者呼叫 nonzero() 時實際上並沒有強制評估?

是的,好問題!所以在目前的形式下,這不是一個硬性阻礙,因為 XLA 張量上的 size() 不會攜帶真實尺寸資訊來源。如範例所示,真實尺寸來源存在於 IRValue 中,並且只能透過 _get_xla_tensor_dimension_size 擷取。因此,如果我們決定 devirtualize size,它只會強制執行這種差異。

作為後續,如果我們讓 size() 傳回張量而不是值,如上面建議的解決方案中所述。在這種情況下,size() 將無法 devirtualize,因為它變成了一個運算子(輸入張量並產生張量,對於不同的後端有不同的實作)。

  1. 例如,如果我在迴圈中呼叫 torch.add(input, 1),其中輸入的大小從 1-1000 不等,通常我們必須編譯 1000 個不同的圖形 - 但使用動態形狀,聽起來 XLA 將能夠在內部產生一個單一圖形,其中它說「如果輸入大小 <=1000,則使用此圖形」。我的問題是:「動態形狀」僅僅是圖形的屬性嗎?還是圖形和輸入兩者的屬性。也就是說,如果我的程式碼改為在迴圈中呼叫 x = torch.add(input, 1); x.sizes(),那麼此時 x 是否具有動態形狀,這意味著我們需要執行圖形才能取得大小?或者即使在存在具有動態形狀的圖形的情況下,我們也能使其成為 eager 計算的屬性。

是的,在這種情況下,您將編譯 1000 個不同的圖形。動態形狀意味著它的輸入具有動態維度。因此,當您查詢 x.sizes()(目前需要使用 get_dimention_size 才能取得正確的大小)時,它將觸發*執行*(由於大小沒有改變,因此不會觸發重新編譯)。如果沒有存取大小的行,當輸入具有動態維度時,它不會觸發任何重新編譯/執行。

  1. 讓控制流程在圖形中可用的替代方案是否只是想出一個方法來確保 XLA 圖形不包含控制流程?也就是說,如果我們有一個中間帶有單個條件的模型,那麼讓 XLA 產生 3 個圖形:1 個用於條件之前的所有內容,1 個用於 if 分支,以及 1 個用於 else 分支。這意味著您不會因為每個路徑組合而獲得新圖形的指數級爆炸,但 (a) 圖形更小,提供的最佳化機會更少,並且 (b) 讓 XLA 識別條件路徑在哪裡可能非常困難。

好主意!所以如果我們可以將它們分解成更小的圖形,那確實是可行的。但在實務中,這種模式很煩人

y = <some computation>
x = y + 2
if x[0] == 2 :
  z = y +1
else:
  z = y - 1

請注意,當您遇到控制流程時,您將使用子圖形評估 x,但分支計算中也可能包含先前的變數(例如 y 只比 x 小一個節點,但在您評估 x 時它沒有被實質化)。因此,對於這個範例,您實際上正在評估 1 個小圖形和兩個大圖形。並且隨著更多控制流程的加入,y 可能會在多個分支中更新,這仍然會產生不同的大圖形組合。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得解答

檢視資源