假張量¶
程式碼: fake_tensor.py
動機¶
在進行 Dynamo 符號求值和編譯器傳遞時,我們經常需要能夠執行張量操作,以了解輸出大小/資料類型/裝置是什麼,而無需實際執行這些操作(或丟棄現有的張量),因為這會比較慢(如果您正在進行大量運算)並且會佔用大量記憶體(如果您的編譯器在編譯程式時需要使用 GPU 記憶體,那就糟了)。Fake Tensor 就像一個真實的張量,除了它實際上沒有任何資料。例如,當我們進行 Dynamo 追蹤時,我們需要追蹤使用者的張量程式碼,並回答有關中繼值的一些問題(例如,如果使用者對中繼張量執行條件判斷)。如果沒有 Fake Tensor,我們將無法獲得這些查詢的準確資訊。
類似地,假設您想為張量儲存元資料,例如,在 FX IR 節點上(meta['val'])。您可以直接在節點上儲存 Fake Tensor,這樣就可以獲得您需要的所有張量元資料,包括您可能沒有處理到的細微之處(例如,別名關係)。
整體架構¶
所有 Fake Tensor 都與 FakeTensorMode 關聯。由於 Fake Tensor 的主要用例是對真實張量進行分析,因此一般的 workflow 是您擁有一堆真實張量,您分配一個 FakeTensorMode,然後您使用 from_real_tensor 將所有這些真實張量轉換為 Fake Tensor,然後您對 Fake Tensor 進行操作。特別是,FakeTensorMode 持續維護一個 memo table,將張量(和儲存體)映射到相同的儲存體。如果您多次 fakeify 同一個張量,您將獲得相同的 Fake Tensor;如果您 fakeify 兩個彼此別名的張量,您將獲得兩個別名相同 Fake Storage 的 Fake Tensor。FakeTensor 是張量子類別,因此如果您對它們執行運算,您將自動獲得 Fake Tensor,但一般來說,您會希望在啟動 FakeTensorMode 的情況下對 Fake Tensor 執行運算(例如,如果您正在運行 FX 傳遞);張量運算將自動開啟 Fake Tensor 模式並再次嘗試。
Fake Tensor 表示為 Meta Tensor 的 __torch_dispatch__ 張量子類別。這意味著在底層,Fake Tensor 是 Meta 裝置張量;然後它們使用額外的可擴展性 hooks,特別是 dispatch_device,來謊報張量的實際裝置是什麼。這是早期 Fake Tensor 中比較容易出錯的部分之一:有時,Fake Tensor 在謊報自己是 CPU/CUDA 什麼的方面太過擅長了,最終你會得到一個 CPU kernel 被調用,並帶有一個 Fake Tensor 試圖解引用資料指標,這顯然行不通。如果您在 Fake Tensor 程式碼中遇到 segmentation fault,這是您應該首先檢查的事情:C++ backtrace 是在 CPU kernel 中(意外!)還是 Meta kernel 中(預期!)?Meta kernel 就像一個真實的 kernel,但它所做的只是分配輸出,它不進行任何資料運算。
張量子類別必須定義如何實作各種運算。以下是一般的 Fake Tensor 食譜
在輸入的 Fake Tensor 上運行 Meta kernel,將它們重新解釋為 Meta Tensor。這是透過一個神奇的上下文管理器 in_kernel_invocation_manager 完成的,它指示所有 PyTorch 將 Fake Tensor 視為其底層 Meta Tensor,而不是將 Fake Tensor「解包」為 Meta Tensor(Fake Tensor 是一個 Meta Tensor)。Fake Tensor 以這種方式表示,以避免必須保持兩組元資料同步(Meta Tensor 的元資料和 Fake Tensor 的元資料);「is a」關係確保只有一個規範的元資料副本。
如果您是 factory function,您將改為使用 device='meta' 調用底層的 factory function。
將產生的 Meta Tensor 轉換為 Fake Tensor,計算張量的輸出裝置應該是什麼(這通常是微不足道的,但有時並非如此,例如,CPU 純量提升或裝置轉換運算)。
API:重要的部分¶
非 PT2 用法(請查看 test/test_fake_tensor.py 以獲得更多範例)
# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
# Do some operations on the fake tensors
fake_y = fake_x * 2
# Factory operations automatically get fakeified in the context manager
fake_z = torch.empty(20)
Q:為什麼您有真實張量作為輸入?
A:在 PT2 上下文中,這是因為您通常正在進行即時編譯,因此對於您正在編譯的圖的所有輸入,您已經有了「真實」輸入,因為您正在執行程式時進行編譯。
PT2 pre-AOTAutograd 用法(這是不尋常的,您可能不想這樣做)
# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
... # do stuff with the fake args, if needed ...
detect_fake_mode 將搜尋許多位置,以嘗試找到與生命週期相關聯的「the」 Fake Tensor 模式。通常,它會從追蹤上下文中提取出來。
PT2 post-AOTAutograd 用法
# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on
其他有用的東西
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
... # fake mode is disabled here, you can do real tensor compute
什麼時候您可能想要停用 Fake Tensor 模式?通常您不想這樣做。我們發現它有用的一個小眾案例是在 Fake Tensor 上實作常數傳播:在這種情況下,即使我們處於 Fake Tensor 模式,我們也需要進行一些實際的張量運算。
import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
詳細資訊¶
自動轉換還是不轉換?最初,如果您嘗試在 FakeTensorMode 區域內對真實張量進行運算,FakeTensorMode 不會自動 fakeify 真實張量。這樣做的動機是為了防止以下 footgun
with FakeTensorMode():
real_tensor.t_()
這段程式碼應該做什麼?如果我們實際修改了真實張量上的元資料,那將令人驚訝。但同時,也沒有任何明顯的機會創建 FakeTensor。因此,我們保守地決定提出一個錯誤:「在 FakeTensorMode 中使用非 Fake Tensor 輸入調用運算符尚不受支持。請首先將所有張量轉換為 FakeTensor。」
實際上,這個錯誤非常煩人。例如,假設您有一個真實的 nn.Module,並且您想要透過它餵入 Fake Tensor。您需要以某種方式 fakeify nn.Module。這促使了 FakeCopyMode 的出現。
最終,我們放棄了,並加入了自動偽造化 (automatic fakeification)。然而,在許多 FakeTensorMode 的使用情境中,預設仍未啟用此功能。
關於 fake tensor 的 metadata 變更:如果你有一個 fake tensor,並且對其使用 t_() 方法,則 fake tensor 上的 metadata 會發生變更。表面上這是合理的,但有時你也會想將 fake tensor 作為 FX 節點上的 metadata 儲存;變更 fake tensor 是不好的,因為這會使舊的 metadata 無效!
事實上,這裡存在一個根本性的矛盾,即 fake tensor 維護關於 tensor 的極其精確的 metadata,包括物件識別。如果物件 metadata 在 FX 圖中隨著時間而改變,實際上沒有任何方法可以表示這種隨著時間的改變。大多數時候,我們嚴重的 FX 分析都是在功能化的圖 (functionalized graphs) 上完成的,這些圖沒有這個問題,但偶爾你需要在非功能化的圖上進行分析。也許將 fake tensor 放入 meta['val'] 是一個錯誤。
關於 tensor 子類別¶
Fake tensor 使用子類別和模式 tensor 子類別的模式,其中 FakeTensor.__torch_dispatch__ 啟用與 fake tensor 相關聯的 FakeTensorMode,然後重新分派 (re-dispatches) (依賴 FakeTensorMode 來完成繁重的工作)。如果 fake tensor 操作獲得了它無法識別的子類別參數,它將返回 NotImplemented,讓另一個子類別有機會首先運行 (希望能夠去糖化 (desugaring) 為普通的 tensor 操作),然後再嘗試一次。這可能會導致無限迴圈。
每個個別的運算子是如何實作的?¶
不幸的是,任何給定的運算子都可能在一組非常複雜的地方實作。需要了解的一些重要情況:
Tensor 子類別支援有限的常數傳播 (constant propagation),如果元素的數量非常小 (這有助於處理我們立即在這些 tensor 上呼叫 item() 的一些情況)。
我們為某些運算子提供了一些快速路徑實作 (fastpath implementations),這些實作完全在 fake tensor 中完成,以提高效能。
如果你使用 @custom_op 產生一個自定義 tensor,這些實作將直接將 impl_abstract 註冊到 fake tensor。
Fake tensor 本身對於 device 轉換操作有一些硬編碼的特殊情況。
如果沒有 meta 實作,也沒有任何分解 (decomposition),我們將產生真實的、零填充的 tensor,並嘗試直接運行該運算子,以找出結果會是什麼。如果該運算子嘗試使用資料進行索引,這可能會導致分段錯誤 (segfaults),因此我們預設不會為自定義操作啟用此功能。
轉換器如何運作?¶
由於 fake tensor 用於對 tensor 的確切屬性非常敏感的情況,因此 fake tensor 會非常小心地進行轉換,保留 leaf-ness、requires_grad'ness、aliasing 和一大堆其他屬性。大部分繁重的工作都在 MetaConverter 中完成。
效能特徵¶
你可能會認為 fake tensor 速度很快,因為它們不做任何 tensor 計算。但在小的 tensor 大小上,我們實際上完全受 overhead 的限制,而且,嗯,fake tensor 是用 Python 編寫的,我們經常做大量的工作來執行單個 tensor 操作 (因為它們是作為分解來實作的)。因此,fake tensor 在實踐中實際上非常慢,尤其是在涉及符號形狀 (symbolic shapes) 時。我們目前在 fake tensor 中有兩個重要的快速路徑,它們在實踐中產生了很大的影響:
Pointwise 操作不會通過 PrimTorch decomps,而是我們手動編碼了它們的傳播規則。
如果可能的話,我們應該這樣做。
Fake tensor of fake tensor?¶
人們有興趣將 fake tensor 作為用戶輸入發送到 PT2 堆疊中,這意味著我們需要能夠創建 fake tensor of a fake tensor。目前這並未真正支援,但或許做起來不會太難。
與動態形狀的互動¶
每個 FakeTensorMode 都包含一個 ShapeEnv,它追蹤所有符號形狀資訊。它們的生命週期通常是綁定的:它們一起生,一起死。
由於 FakeTensorMode 有一個 ShapeEnv (但 meta 實作沒有),因此數據相關且需要分配未備份的 SymInt 的 meta 函數存在於 fake tensor 中。Fake tensor 也負責記憶 (memoizing) 未備份的 SymInts,因此,例如,如果你在同一個 fake tensor 上呼叫 nonzero() 兩次,你會得到相同符號大小。