自訂 C++ 和 CUDA 運算子¶
建立於:2024 年 6 月 18 日 | 最後更新:2025 年 1 月 28 日 | 最後驗證:2024 年 11 月 05 日
作者: Richard Zou
如何將以 C++/CUDA 撰寫的自訂運算子與 PyTorch 整合
如何使用
torch.library.opcheck
測試自訂運算子
PyTorch 2.4 或更高版本
C++ 和 CUDA 程式設計的基本理解
注意
本教學也適用於 AMD ROCm,無需額外修改。
PyTorch 提供一個大型的運算子函式庫,可處理 Tensor (例如 torch.add、torch.sum 等)。 但是,您可能希望將新的自訂運算子引入 PyTorch。 本教學示範了撰寫以 C++/CUDA 撰寫的自訂運算子的最佳途徑。
在本教學中,我們將示範如何撰寫一個融合的乘加 C++ 和 CUDA 運算子,該運算子與 PyTorch 子系統組合。 運算的語義如下
def mymuladd(a: Tensor, b: Tensor, c: float):
return a * b + c
您可以在這裡找到本教學的端到端工作範例。
設定建置系統¶
如果您正在開發自訂 C++/CUDA 程式碼,則必須對其進行編譯。 請注意,如果您正在與已經綁定到預編譯 C++/CUDA 程式碼的 Python 函式庫進行介面,您可以考慮編寫自訂 Python 運算子 (自訂 Python 運算子)。
使用 torch.utils.cpp_extension 編譯自訂 C++/CUDA 程式碼以與 PyTorch 一起使用。 C++ 擴充功能可以通過 setuptools “提前” 建置,也可以通過 load_inline “即時” 建置;我們將著重於 “提前” 的方式。
使用 cpp_extension
就像編寫以下 setup.py
一樣簡單
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name="extension_cpp",
ext_modules=[
cpp_extension.CppExtension(
"extension_cpp",
["muladd.cpp"],
# define Py_LIMITED_API with min version 3.9 to expose only the stable
# limited API subset from Python.h
extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]},
py_limited_api=True)], # Build 1 wheel across multiple Python versions
cmdclass={'build_ext': cpp_extension.BuildExtension},
options={"bdist_wheel": {"py_limited_api": "cp39"}} # 3.9 is minimum supported Python version
)
如果您需要編譯 CUDA 程式碼 (例如,.cu
檔案),則改為使用 torch.utils.cpp_extension.CUDAExtension。 請參閱 extension-cpp 以獲取如何設定的範例。
上面的範例表示我們所說的與 CPython 無關的 wheel,表示我們正在建置一個可以在多個 CPython 版本上執行的單個 wheel (類似於純 Python 套件)。 在最小化您的自訂函式庫需要支援和發布的 wheel 數量方面,CPython 不可知論是理想的。 我們希望支援的最低版本是 3.9,因為它是目前支援的最舊版本,因此我們在整個設定程式碼中使用相應的十六進制程式碼和規範符。 我們建議在與您想要支援的最低 CPython 版本相同的環境中建置擴充功能,以最大程度地減少未知行為,因此,在這裡,我們在 CPython 3.9 環境中建置擴充功能。 建置完成後,此單個 wheel 將可在任何 CPython 環境 3.9+ 中執行。 為了實現這一點,有三個關鍵行需要注意。
第一個是在 extra_compile_args
中指定 Py_LIMITED_API
到您想要支援的最低 CPython 版本
extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]},
定義 Py_LIMITED_API
標誌有助於驗證擴充功能實際上僅使用 CPython Stable Limited API,這是建置與 CPython 無關的 wheel 的要求。 如果不滿足此要求,則可以建置一個看起來與 CPython 無關的 wheel,但在另一個 CPython 環境中會崩潰,或者更糟的是,會默默地不正確。 請注意避免使用不穩定的 CPython API,例如來自 libtorch_python 的 API (尤其是 pytorch/python 綁定),並且僅使用來自 libtorch 的 API (ATen 物件、運算子和分派器)。 我們強烈建議定義 Py_LIMITED_API
標誌,以幫助確定擴充功能是否合規且作為與 CPython 無關的 wheel 是安全的。 請注意,定義此標誌並不能完全保證建置的 wheel 與 CPython 無關,但總比混亂的情況要好。 Python 文件中提到了幾個注意事項,您應該自己測試和驗證 wheel 對於相關 CPython 版本而言是否真正不可知。
指定 py_limited_api
的第二行和第三行通知 setuptools 您打算建置一個與 CPython 無關的 wheel,並將相應地影響 wheel 的命名
setup(name="extension_cpp",
ext_modules=[
cpp_extension.CppExtension(
...,
py_limited_api=True)], # Build 1 wheel across multiple Python versions
...,
options={"bdist_wheel": {"py_limited_api": "cp39"}} # 3.9 is minimum supported Python version
)
必須指定 py_limited_api=True
作為 CppExtension/ CUDAExtension 的引數,同時也要作為 "bdist_wheel"
指令的選項,並指定最小支援的 CPython 版本 (在此範例中為 3.9)。因此,我們教學中的 setup
會建置一個名稱正確的 wheel 檔案,它可以安裝在多個 CPython 版本 >=3.9
上。
如果您的擴充功能使用了 stable limited set 之外的 CPython API,那麼您就無法建置 CPython 不可知論的 wheel 檔案!您應該為每個 CPython 版本建置一個 wheel 檔案,如下所示:
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name="extension_cpp",
ext_modules=[
cpp_extension.CppExtension(
"extension_cpp",
["muladd.cpp"])],
cmdclass={'build_ext': cpp_extension.BuildExtension},
)
定義自定義運算元並新增後端實作¶
首先,讓我們編寫一個 C++ 函式來計算 mymuladd
at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
}
return result;
}
為了從 PyTorch 的 Python 前端使用它,我們需要使用 TORCH_LIBRARY
API 將其註冊為 PyTorch 運算元。 這將自動將運算元綁定到 Python。
運算元註冊是一個兩步驟的過程
定義運算元 - 此步驟確保 PyTorch 知道新的運算元。
註冊後端實作 - 在此步驟中,各種後端的實作(例如 CPU 和 CUDA)與運算元相關聯。
定義運算元¶
要定義運算元,請按照以下步驟操作
為運算元選擇一個命名空間。我們建議命名空間是您的頂層專案的名稱;在我們的教學中,我們將使用 “extension_cpp”。
提供一個 schema 字串,用於指定運算元的輸入/輸出類型,以及是否會改變輸入張量。除了 Tensor 和 float 之外,我們還支援更多類型;有關更多詳細資訊,請參閱 自定義運算元手冊。
如果您正在編寫可以改變其輸入張量的運算元,請參閱此處 (建立可變運算元) 了解如何指定。
TORCH_LIBRARY(extension_cpp, m) {
// Note that "float" in the schema corresponds to the C++ double type
// and the Python float type.
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
}
這使得可以透過 torch.ops.extension_cpp.mymuladd
從 Python 使用該運算元。
為運算元註冊後端實作¶
使用 TORCH_LIBRARY_IMPL
為運算元註冊後端實作。
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", &mymuladd_cpu);
}
如果您還有 myaddmul
的 CUDA 實作,您可以在單獨的 TORCH_LIBRARY_IMPL
區塊中註冊它
__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}
at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
int numel = a_contig.numel();
muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
return result;
}
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
m.impl("mymuladd", &mymuladd_cuda);
}
為運算元新增 torch.compile
支援¶
若要為運算元新增 torch.compile
支援,我們必須新增一個 FakeTensor 核心(也稱為「meta 核心」或「abstract impl」)。 FakeTensor 是具有元數據(例如形狀、dtype、設備)但沒有數據的 Tensor:運算元的 FakeTensor 核心指定如何根據輸入張量的元數據計算輸出張量的元數據。 FakeTensor 核心應返回您選擇的,具有正確 Tensor 元數據(形狀/步幅/dtype
/設備)的虛擬 Tensor。
我們建議透過 torch.library.register_fake
API 從 Python 執行此操作,儘管也可以從 C++ 執行此操作(有關更多詳細資訊,請參閱 自定義運算元手冊)。
# Important: the C++ custom operator definitions should be loaded first
# before calling ``torch.library`` APIs that add registrations for the
# C++ custom operator(s). The following import loads our
# C++ custom operator definitions.
# Note that if you are striving for Python agnosticism, you should use
# the ``load_library(...)`` API call instead. See the next section for
# more details.
from . import _C
@torch.library.register_fake("extension_cpp::mymuladd")
def _(a, b, c):
torch._check(a.shape == b.shape)
torch._check(a.dtype == torch.float)
torch._check(b.dtype == torch.float)
torch._check(a.device == b.device)
return torch.empty_like(a)
設定混合 Python/C++ 註冊¶
在本教學中,我們在 C++ 中定義了一個自定義運算元,在 C++ 中新增了 CPU/CUDA 實作,並在 Python 中新增了 FakeTensor
核心和向後公式。 這些註冊的載入(或匯入)順序很重要(以錯誤的順序匯入將導致錯誤)。
若要將自定義運算元與混合 Python/C++ 註冊一起使用,我們必須首先載入包含自定義運算元定義的 C++ 庫,然後呼叫 torch.library
註冊 API。 可以透過三種方式實現此目的
載入包含自定義運算元定義的 C++ 庫的第一種方法是為 _C 定義一個虛擬 Python 模組。 然後,在 Python 中,當您使用
import _C
匯入模組時,將載入與擴充功能對應的.so
檔案,並且將執行TORCH_LIBRARY
和TORCH_LIBRARY_IMPL
靜態初始化器。 可以使用類似下面的PYBIND11_MODULE
建立虛擬 Python 模組,但您會注意到這無法使用Py_LIMITED_API
編譯,因為pybind11
並不保證只使用 stable limited CPython API! 使用下面的程式碼,您很遺憾無法為您的擴充功能建置 CPython 不可知論的 wheel 檔案! (伏筆:我想知道第二種方法是什麼 ;) )。
// in, say, not_agnostic/csrc/extension_BAD.cpp
#include <pybind11/pybind11.h>
PYBIND11_MODULE("_C", m) {}
# in, say, extension/__init__.py
from . import _C
在本教學中,因為我們重視能夠在多個 CPython 版本之間建置單個 wheel 檔案,所以我們將使用 stable API 呼叫替換不穩定的
PYBIND11
呼叫。 下面的程式碼使用-DPy_LIMITED_API=0x03090000
編譯,並成功為我們的_C
擴充功能建立虛擬 Python 模組,以便可以從 Python 匯入。 有關更多詳細資訊,請參閱 extension_cpp/__init__.py 和 extension_cpp/csrc/muladd.cpp。
#include <Python.h>
extern "C" {
/* Creates a dummy empty _C module that can be imported from Python.
The import from Python will load the .so consisting of this file
in this extension, so that the TORCH_LIBRARY static initializers
below are run. */
PyObject* PyInit__C(void)
{
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_C", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
NULL, /* methods */
};
return PyModule_Create(&module_def);
}
}
# in, say, extension/__init__.py
from . import _C
如果您想完全避免在 C++ 自定義運算元中使用
Python.h
,您可以使用 Python 中的torch.ops.load_library("/path/to/library.so")
載入從擴充功能編譯的.so
檔案。 請注意,使用此方法,不會為擴充功能建立_C
Python 模組,因此您無法從 Python 呼叫import _C
。torch.ops.load_library("/path/to/library.so")
將執行此操作,而不是依賴匯入語句來觸發自定義運算元的註冊。 然後,挑戰轉向了解.so
檔案的位置,以便您可以載入它們,這並非總是那麼容易
import torch
from pathlib import Path
so_files = list(Path(__file__).parent.glob("_C*.so"))
assert (
len(so_files) == 1
), f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])
from . import ops
為運算元新增訓練 (autograd) 支援¶
使用 torch.library.register_autograd
來為運算子新增訓練支援。 相比於直接使用 Python 的 torch.autograd.Function
或 C++ 的 torch::autograd::Function
,此方法更佳;您必須以非常特定的方式使用它們,才能避免無聲的錯誤(詳情請參閱自定義運算子手冊)。
def _backward(ctx, grad):
a, b = ctx.saved_tensors
grad_a, grad_b = None, None
if ctx.needs_input_grad[0]:
grad_a = grad * b
if ctx.needs_input_grad[1]:
grad_b = grad * a
return grad_a, grad_b, None
def _setup_context(ctx, inputs, output):
a, b, c = inputs
saved_a, saved_b = None, None
if ctx.needs_input_grad[0]:
saved_b = b
if ctx.needs_input_grad[1]:
saved_a = a
ctx.save_for_backward(saved_a, saved_b)
# This code adds training support for the operator. You must provide us
# the backward formula for the operator and a `setup_context` function
# to save values to be used in the backward.
torch.library.register_autograd(
"extension_cpp::mymuladd", _backward, setup_context=_setup_context)
請注意,backward 必須是 PyTorch 可理解的運算子的組合。 如果您希望在 backwards pass 中使用另一個自定義的 C++ 或 CUDA 核心,則必須將其包裝到自定義運算子中。
如果我們有自己的自定義 mymul
核心,我們需要將其包裝到自定義運算子中,然後從 backward 呼叫它。
// New! a mymul_cpu kernel
at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_CHECK(a.device().type() == at::DeviceType::CPU);
TORCH_CHECK(b.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = a_ptr[i] * b_ptr[i];
}
return result;
}
TORCH_LIBRARY(extension_cpp, m) {
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
// New! defining the mymul operator
m.def("mymul(Tensor a, Tensor b) -> Tensor");
}
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", &mymuladd_cpu);
// New! registering the cpu kernel for the mymul operator
m.impl("mymul", &mymul_cpu);
}
def _backward(ctx, grad):
a, b = ctx.saved_tensors
grad_a, grad_b = None, None
if ctx.needs_input_grad[0]:
grad_a = torch.ops.extension_cpp.mymul.default(grad, b)
if ctx.needs_input_grad[1]:
grad_b = torch.ops.extension_cpp.mymul.default(grad, a)
return grad_a, grad_b, None
def _setup_context(ctx, inputs, output):
a, b, c = inputs
saved_a, saved_b = None, None
if ctx.needs_input_grad[0]:
saved_b = b
if ctx.needs_input_grad[1]:
saved_a = a
ctx.save_for_backward(saved_a, saved_b)
# This code adds training support for the operator. You must provide us
# the backward formula for the operator and a `setup_context` function
# to save values to be used in the backward.
torch.library.register_autograd(
"extension_cpp::mymuladd", _backward, setup_context=_setup_context)
測試運算子¶
使用 torch.library.opcheck
來測試自定義運算子是否已正確註冊。 請注意,此函式不測試梯度在數學上是否正確 - 計劃編寫單獨的測試,手動測試或使用 torch.autograd.gradcheck
。
def sample_inputs(device, *, requires_grad=False):
def make_tensor(*size):
return torch.randn(size, device=device, requires_grad=requires_grad)
def make_nondiff_tensor(*size):
return torch.randn(size, device=device, requires_grad=False)
return [
[make_tensor(3), make_tensor(3), 1],
[make_tensor(20), make_tensor(20), 3.14],
[make_tensor(20), make_nondiff_tensor(20), -123],
[make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
]
def reference_muladd(a, b, c):
return a * b + c
samples = sample_inputs(device, requires_grad=True)
samples.extend(sample_inputs(device, requires_grad=False))
for args in samples:
# Correctness test
result = torch.ops.extension_cpp.mymuladd(*args)
expected = reference_muladd(*args)
torch.testing.assert_close(result, expected)
# Use opcheck to check for incorrect usage of operator registration APIs
torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, args)
建立可變運算子¶
您可能希望撰寫一個會改變其輸入的自定義運算子。 使用 Tensor(a!)
來指定 schema 中的每個可變 Tensor; 否則,會發生未定義的行為。 如果有多個已改變的 Tensor,請對每個可變 Tensor 使用不同的名稱 (例如, Tensor(a!)
、 Tensor(b!)
、 Tensor(c!)
)。
讓我們撰寫一個 myadd_out(a, b, out)
運算子,它將 a+b
的內容寫入 out
。
// An example of an operator that mutates one of its inputs.
void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
TORCH_CHECK(a.sizes() == b.sizes());
TORCH_CHECK(b.sizes() == out.sizes());
TORCH_CHECK(a.dtype() == at::kFloat);
TORCH_CHECK(b.dtype() == at::kFloat);
TORCH_CHECK(out.dtype() == at::kFloat);
TORCH_CHECK(out.is_contiguous());
TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
at::Tensor a_contig = a.contiguous();
at::Tensor b_contig = b.contiguous();
const float* a_ptr = a_contig.data_ptr<float>();
const float* b_ptr = b_contig.data_ptr<float>();
float* result_ptr = out.data_ptr<float>();
for (int64_t i = 0; i < out.numel(); i++) {
result_ptr[i] = a_ptr[i] + b_ptr[i];
}
}
在定義運算子時,我們必須在 schema 中指定它會改變 out Tensor
TORCH_LIBRARY(extension_cpp, m) {
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
m.def("mymul(Tensor a, Tensor b) -> Tensor");
// New!
m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
}
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
m.impl("mymuladd", &mymuladd_cpu);
m.impl("mymul", &mymul_cpu);
// New!
m.impl("myadd_out", &myadd_out_cpu);
}
注意
不要將任何已改變的 Tensor 作為運算子的輸出回傳,因為這會導致與 PyTorch 子系統 (例如 torch.compile
) 不相容。