OP 算子降低指南¶
PyTorch 包裝了 C++ ATen 張量函式庫,該函式庫提供了在 GPU 和 CPU 上實作的各種運算。Pytorch/XLA 是 PyTorch 的擴充功能;其目的之一是將 PyTorch 運算轉換為 XLA 運算。「降低」定義了將較高層級的表示法轉換為較低層級表示法的過程。在本文檔中,我將把 PyTorch 運算轉換為 XLA 運算的過程稱為「降低」。XLA 編譯器也會將 XlaOp 降低為 HLO,但這超出了本文檔的範圍。對於我們尚未提供 XLA 降低的運算,我們將轉發到 CPU 並調用 ATen 實作。轉發到 CPU 的運算將導致顯著的效能降低。我們必須降低模型中使用的所有運算,以實現最佳效能。
以下範例說明了對於尚未降低的運算,您可能會從 PyTorch/XLA 偵錯工具中看到的內容
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests.
開始之前¶
您應該遵循貢獻 Pytorch/XLA中的說明,安裝所需的依賴項,並從原始碼建置 pytorch 和 pytorch/XLA。您不需要存取 TPU 即可實作降低。建議在工作站上進行實驗,並將其設定為使用 XLA:CPU。您可以透過執行以下命令將 Pytorch/XLA 設定為使用 XLA:CPU:
export PJRT_DEVICE=CPU
了解運算¶
您可以在native_functions.yaml中找到 C++ ATen 運算的定義。從原始碼建置 Pytorch/XLA 後,您還可以在 xla/torch_xla/csrc/aten_fallback.h/cpp
中找到我們的預設實作(一個 boxed kernel,它將呼叫轉發到 PyTorch 原生 kernel)。Pytorch 運算通常可以輕鬆地對應到PyTorch 張量 API。如果不是這種情況,建議在PyTorch repo下搜尋 PyTorch 原生實作。目標是將 PyTorch 運算降低為XLA 運算語義中定義的一系列 XLA 運算。
檔案結構¶
以下提及的所有檔案都位於 xla/torch_xla/csrc
資料夾下,但 codegen/xla_native_functions.yaml
除外
xla_native_functions.yaml
包含明確降低的所有運算符(來自核心 Aten 列表)的列表。組合運算符未在此處列出。此處的每個運算符名稱都必須直接匹配native_functions.yaml中列出的 pytorch 運算符。此檔案作為新增 xla 運算符的介面,並且是 PyTorch 的codegen 機制的輸入。它會產生以下 3 個檔案:XLANativeFunctions.h
、RegisterXLA.cpp
和RegisterAutogradXLA.cpp
XLANativeFunctions.h
和aten_xla_type.cpp
是 PyTorch 進入 pytorch_xla 世界的入口點,並且包含針對每個運算符手動編寫的 XLA 降低。XLANativeFunctions.h
是透過xla_native_functions.yaml
和 PyTorch 核心native_functions.yaml
檔案的組合自動產生的,並且包含需要在aten_xla_type.cpp
中定義的 kernel 的宣告。此處編寫的 kernel 需要使用輸入at::Tensor
和其他參數來建構 'XLATensor'。產生的XLATensor
需要在返回 PyTorch 世界之前轉換回at::Tensor
。RegisterXLA.cpp
和RegisterAutogradXLA.cpp
是自動產生的檔案,用於向 PyTorch Dispatcher 註冊所有降低。它們還包括out=
和inplace
運算符的自動產生包裝函式實作。aten_fallback.h/.cpp
包含我們的 boxed fallback 實作。如果未在xla_native_functions.yaml
+aten_xla_type.cpp
中明確定義降低,且運算符不是組合運算符,則將使用 boxed fallback kernel。tensor_methods.h
包含XLATensor
宣告。這些宣告通常是我們在XLANativeFunctions.h
中宣告的at::Tensor
節點的一對一映射。tensor_methods.cpp
包含tensor_methods.h
中定義的XLATensor
節點的實作。我們從參數的ir::Value
建構了對應的ir::op
,並將其包裝在XLATensor
內。Ir 代表中間表示法。ops/
目錄包含所有ir::ops
宣告和定義。較小的節點可以放在ops/ops.h/.cpp
中。較複雜的節點可以放在單獨的檔案中。所有 ops 都繼承自ir::ops::Node
,並提供一種將輸入ir::Value
降低為一系列XlaOp
的方法。
單元測試¶
我們的 CI 會針對每個變更和每天執行 PyTorch 原生 python 測試。如果我們提供降低,這些測試將使用 XLA 實作。除非我們想要驗證一些 xla 行為(例如動態形狀),或者因為某些原因我們跳過了 pytorch 原生測試,否則我們通常不需要為 PyTorch/XLA 新增額外的 python 測試。如果需要,應將 python 測試新增至 xla/test/test_operations.py
。我們還需要在 xla/test/cpp/test_aten_xla_tensor.cpp
中新增 CPP 測試。此測試應調用 PyTorch c++ API,並驗證我們的實作是否產生與 PyTorch 原生實作相同的結果。我們還需要透過檢查 aten::op
和 xla::op
計數器來驗證當張量是 XLA 張量時是否調用了 xla 實作。
提示¶
降低的過程是將 PyTorch 運算分解為一系列 XlaOp。為了提供 PyTorch 運算的良好降低,需要充分掌握 XLA 的功能。閱讀 XlaOp 文件並研究類似的運算如何降低是實現這一目標的最佳方法。您可以在此 Op 降低 PR中找到一個最小的 Op 降低範例。您還可以在此反向降低 PR中找到一個稍微複雜的反向降低範例。
對於 RegisterXLA.cpp
中的某些運算符,我們具有 out=
和 inplace
運算符的自動產生包裝函式實作。在這種情況下,我們只需要降低 vanilla op。一個範例是 lerp
運算符,它在 native_functions.yaml
中有 6 個變體,它們是
- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor
並且將產生函數原型
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out);
如果在我們將所有變體新增至 xla_native_functions.yaml
,則在 XLANativeFunctions.h
中。但是,如果我們僅降低 lerp.Scalar
和 lerp.Tensor
並檢查 RegisterXLA.cpp
,我們將看到
namespace {
at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
// No device check
// DeviceGuard omitted
return torch_xla::lerp(self, end, weight);
}
} // anonymous namespace
at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight);
at::_copy_from(wrapper_Scalar_lerp__tmp, self);
return self;
}
...
m.impl("lerp_.Scalar",
TORCH_FN(wrapper_Scalar_lerp_));
Codegen 將自動產生使用我們的 lerp.Scalar
實作的 lerp_.Scalar
和 lerp.Scalar_out
的降低,而無需我們提供明確的降低。
一般而言,如果 pytorch 核心中有一個運算符同時具有 out-of-place 和 out= 變體,最好為 out-of-place 變體編寫降低,因為您將免費獲得程式碼產生的 out= 降低。
對於每個節點,我們都需要傳遞一個 ir::OpKind
。這是一個範例 (範例)。您可以在interned_strings.h中找到 OpKind
定義。如果 aten 符號遺失,您可以提交類似此 PR的 PR。