• 文件 >
  • OP 算子降低指南
快捷鍵

OP 算子降低指南

PyTorch 包裝了 C++ ATen 張量函式庫,該函式庫提供了在 GPU 和 CPU 上實作的各種運算。Pytorch/XLA 是 PyTorch 的擴充功能;其目的之一是將 PyTorch 運算轉換為 XLA 運算。「降低」定義了將較高層級的表示法轉換為較低層級表示法的過程。在本文檔中,我將把 PyTorch 運算轉換為 XLA 運算的過程稱為「降低」。XLA 編譯器也會將 XlaOp 降低為 HLO,但這超出了本文檔的範圍。對於我們尚未提供 XLA 降低的運算,我們將轉發到 CPU 並調用 ATen 實作。轉發到 CPU 的運算將導致顯著的效能降低。我們必須降低模型中使用的所有運算,以實現最佳效能。

以下範例說明了對於尚未降低的運算,您可能會從 PyTorch/XLA 偵錯工具中看到的內容

pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward,  Please open a GitHub issue with the above op lowering requests.

開始之前

您應該遵循貢獻 Pytorch/XLA中的說明,安裝所需的依賴項,並從原始碼建置 pytorch 和 pytorch/XLA。您不需要存取 TPU 即可實作降低。建議在工作站上進行實驗,並將其設定為使用 XLA:CPU。您可以透過執行以下命令將 Pytorch/XLA 設定為使用 XLA:CPU:

export PJRT_DEVICE=CPU

了解運算

您可以在native_functions.yaml中找到 C++ ATen 運算的定義。從原始碼建置 Pytorch/XLA 後,您還可以在 xla/torch_xla/csrc/aten_fallback.h/cpp 中找到我們的預設實作(一個 boxed kernel,它將呼叫轉發到 PyTorch 原生 kernel)。Pytorch 運算通常可以輕鬆地對應到PyTorch 張量 API。如果不是這種情況,建議在PyTorch repo下搜尋 PyTorch 原生實作。目標是將 PyTorch 運算降低為XLA 運算語義中定義的一系列 XLA 運算。

檔案結構

以下提及的所有檔案都位於 xla/torch_xla/csrc 資料夾下,但 codegen/xla_native_functions.yaml 除外

  1. xla_native_functions.yaml 包含明確降低的所有運算符(來自核心 Aten 列表)的列表。組合運算符未在此處列出。此處的每個運算符名稱都必須直接匹配native_functions.yaml中列出的 pytorch 運算符。此檔案作為新增 xla 運算符的介面,並且是 PyTorch 的codegen 機制的輸入。它會產生以下 3 個檔案:XLANativeFunctions.hRegisterXLA.cppRegisterAutogradXLA.cpp

  2. XLANativeFunctions.haten_xla_type.cpp 是 PyTorch 進入 pytorch_xla 世界的入口點,並且包含針對每個運算符手動編寫的 XLA 降低。XLANativeFunctions.h 是透過 xla_native_functions.yaml 和 PyTorch 核心 native_functions.yaml 檔案的組合自動產生的,並且包含需要在 aten_xla_type.cpp 中定義的 kernel 的宣告。此處編寫的 kernel 需要使用輸入 at::Tensor 和其他參數來建構 'XLATensor'。產生的 XLATensor 需要在返回 PyTorch 世界之前轉換回 at::Tensor

  3. RegisterXLA.cppRegisterAutogradXLA.cpp 是自動產生的檔案,用於向 PyTorch Dispatcher 註冊所有降低。它們還包括 out=inplace 運算符的自動產生包裝函式實作。

  4. aten_fallback.h/.cpp 包含我們的 boxed fallback 實作。如果未在 xla_native_functions.yaml + aten_xla_type.cpp 中明確定義降低,且運算符不是組合運算符,則將使用 boxed fallback kernel。

  5. tensor_methods.h 包含 XLATensor 宣告。這些宣告通常是我們在 XLANativeFunctions.h 中宣告的 at::Tensor 節點的一對一映射。

  6. tensor_methods.cpp 包含 tensor_methods.h 中定義的 XLATensor 節點的實作。我們從參數的 ir::Value 建構了對應的 ir::op,並將其包裝在 XLATensor 內。Ir 代表中間表示法。

  7. ops/ 目錄包含所有 ir::ops 宣告和定義。較小的節點可以放在 ops/ops.h/.cpp 中。較複雜的節點可以放在單獨的檔案中。所有 ops 都繼承自 ir::ops::Node,並提供一種將輸入 ir::Value 降低為一系列 XlaOp 的方法。

單元測試

我們的 CI 會針對每個變更和每天執行 PyTorch 原生 python 測試。如果我們提供降低,這些測試將使用 XLA 實作。除非我們想要驗證一些 xla 行為(例如動態形狀),或者因為某些原因我們跳過了 pytorch 原生測試,否則我們通常不需要為 PyTorch/XLA 新增額外的 python 測試。如果需要,應將 python 測試新增至 xla/test/test_operations.py。我們還需要在 xla/test/cpp/test_aten_xla_tensor.cpp 中新增 CPP 測試。此測試應調用 PyTorch c++ API,並驗證我們的實作是否產生與 PyTorch 原生實作相同的結果。我們還需要透過檢查 aten::opxla::op 計數器來驗證當張量是 XLA 張量時是否調用了 xla 實作。

提示

降低的過程是將 PyTorch 運算分解為一系列 XlaOp。為了提供 PyTorch 運算的良好降低,需要充分掌握 XLA 的功能。閱讀 XlaOp 文件並研究類似的運算如何降低是實現這一目標的最佳方法。您可以在此 Op 降低 PR中找到一個最小的 Op 降低範例。您還可以在此反向降低 PR中找到一個稍微複雜的反向降低範例。

對於 RegisterXLA.cpp 中的某些運算符,我們具有 out=inplace 運算符的自動產生包裝函式實作。在這種情況下,我們只需要降低 vanilla op。一個範例是 lerp 運算符,它在 native_functions.yaml 中有 6 個變體,它們是

- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor

並且將產生函數原型

at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out);

如果在我們將所有變體新增至 xla_native_functions.yaml,則在 XLANativeFunctions.h 中。但是,如果我們僅降低 lerp.Scalarlerp.Tensor 並檢查 RegisterXLA.cpp,我們將看到

namespace {

at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
    // No device check


  // DeviceGuard omitted
  return torch_xla::lerp(self, end, weight);
}

} // anonymous namespace

at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
  auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight);
  at::_copy_from(wrapper_Scalar_lerp__tmp, self);
  return self;
}

...
  m.impl("lerp_.Scalar",
  TORCH_FN(wrapper_Scalar_lerp_));

Codegen 將自動產生使用我們的 lerp.Scalar 實作的 lerp_.Scalarlerp.Scalar_out 的降低,而無需我們提供明確的降低。

一般而言,如果 pytorch 核心中有一個運算符同時具有 out-of-place 和 out= 變體,最好為 out-of-place 變體編寫降低,因為您將免費獲得程式碼產生的 out= 降低。

對於每個節點,我們都需要傳遞一個 ir::OpKind。這是一個範例 (範例)。您可以在interned_strings.h中找到 OpKind 定義。如果 aten 符號遺失,您可以提交類似此 PR的 PR。

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得解答

查看資源