Torch 函式庫 API¶
PyTorch C++ API 提供擴展 PyTorch 核心運算元函式庫的功能,包括使用者定義的運算元和資料類型。 使用 Torch 函式庫 API 實作的擴充功能可在 PyTorch eager API 和 TorchScript 中使用。
如需函式庫 API 的教學式簡介,請參閱使用自訂 C++ 運算元擴展 TorchScript教學課程。
巨集¶
-
TORCH_LIBRARY(ns, m)
用於定義在靜態初始化時運行的函數的巨集,以在命名空間
ns
中定義運算元函式庫(必須是有效的 C++ 識別碼,沒有引號)。當您想要定義 PyTorch 中還不存在的一組新的自訂運算元時,請使用此巨集。
使用範例
TORCH_LIBRARY(myops, m) { // m is a torch::Library; methods on it will define // operators in the myops namespace m.def("add", add_impl); }
m
引數繫結到 torch::Library,用於註冊運算元。 對於任何給定的命名空間,只能有一個 TORCH_LIBRARY()。
-
TORCH_LIBRARY_IMPL(ns, k, m)
用於定義在靜態初始化時運行的函數的巨集,以在命名空間
ns
中定義分派鍵k
(必須是 c10::DispatchKey 的未限定列舉成員)的運算元覆寫(必須是有效的 C++ 識別碼,沒有引號)。當您想要在新分派鍵上實作一組預先存在的自訂運算元時,請使用此巨集(例如,您想要提供現有運算元的 CUDA 實作)。 一種常見的用法模式是使用 TORCH_LIBRARY() 來定義您要定義的所有新運算元的綱要,然後使用幾個 TORCH_LIBRARY_IMPL() 區塊來提供 CPU、CUDA 和 Autograd 運算元的實作。
在某些情況下,您需要定義適用於所有命名空間(而不僅僅是一個命名空間)的內容(通常是備用方法)。 在這種情況下,請使用保留的命名空間 _,例如:
TORCH_LIBRARY_IMPL(_, XLA, m) { m.fallback(xla_fallback); }
使用範例
TORCH_LIBRARY_IMPL(myops, CPU, m) { // m is a torch::Library; methods on it will define // CPU implementations of operators in the myops namespace. // It is NOT valid to call torch::Library::def() // in this context. m.impl("add", add_cpu_impl); }
如果
add_cpu_impl
是重載函數,請使用static_cast
來指定您想要的重載(透過提供完整類型)。
類別¶
-
class Library
此物件提供 API 以定義運算子,並在分發鍵 (dispatch key) 上提供實作。
通常,torch::Library 不會直接配置;而是由 TORCH_LIBRARY() 或 TORCH_LIBRARY_IMPL() 巨集建立。
torch::Library 上的大多數方法都會傳回對自身的參考,以支援方法鏈。
// Examples: TORCH_LIBRARY(torchvision, m) { // m is a torch::Library m.def("roi_align", ...); ... } TORCH_LIBRARY_IMPL(aten, XLA, m) { // m is a torch::Library m.impl("add", ...); ... }
Public Functions
-
template<typename Schema>
inline Library &def(Schema &&raw_schema, const std::vector<at::Tag> &tags = {}, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & 宣告具有綱要 (schema) 的運算子,但不為其提供任何實作。
您應該使用 impl() 方法來提供實作。所有模板引數都會被推斷出來。
// Example: TORCH_LIBRARY(myops, m) { m.def("add(Tensor self, Tensor other) -> Tensor"); }
- 參數
raw_schema – 要定義的運算子的綱要。 通常,這是一個
const char*
字串文字,但此處接受 torch::schema() 接受的任何類型。
-
inline Library &set_python_module(const char *pymodule, const char *context = "")
宣告對於隨後定義 (def) 的所有運算子,其虛擬實作 (fake impl) 可以在給定的 Python 模組 (pymodule) 中找到。
如果找不到虛擬實作,這會註冊一些說明文字。
Args
pymodule: python 模組
context: 我們可能會將此包含在錯誤訊息中。
-
inline Library &impl_abstract_pystub(const char *pymodule, const char *context = "")
已棄用;請改用 set_python_module。
-
template<typename NameOrSchema, typename Func>
inline Library &def(NameOrSchema &&raw_name_or_schema, Func &&raw_f, const std::vector<at::Tag> &tags = {}) & 為一個 schema 定義運算子,然後為它註冊一個實作。
如果您不打算使用 dispatcher 來組織您的運算子實作,這通常是您會使用的。它大致相當於呼叫 def() 然後呼叫 impl(),但如果您省略了運算子的 schema,我們會從您的 C++ 函數類型推斷它。所有模板引數都會被推斷出來。
// Example: TORCH_LIBRARY(myops, m) { m.def("add", add_fn); }
- 參數
raw_name_or_schema – 要定義的運算子的 schema,或者如果 schema 要從
raw_f
推斷,則僅為運算子的名稱。通常是一個const char*
字串。raw_f – 實作此運算子的 C++ 函數。此處接受 torch::CppFunction 的任何有效建構函式;通常您提供函數指標或 lambda。
-
template<typename Name, typename Func>
inline Library &impl(Name name, Func &&raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & 為運算子註冊一個實作。
您可以為單一運算子,在不同的 dispatch keys 上註冊多個實作 (請參閱 torch::dispatch())。實作必須具有對應的宣告 (來自 def()),否則它們將無效。如果您打算註冊多個實作,請在您 def() 運算子時,不要提供函數實作。
// Example: TORCH_LIBRARY_IMPL(myops, CUDA, m) { m.impl("add", add_cuda); }
- 參數
name – 要實作的運算子的名稱。請勿在此處提供 schema。
raw_f – 實作此運算子的 C++ 函數。此處接受 torch::CppFunction 的任何有效建構函式;通常您提供函數指標或 lambda。
-
template<typename Func>
inline Library &fallback(Func &&raw_f) & 為所有運算子註冊一個 fallback 實作,如果沒有可用的特定運算子實作,將會使用此實作。
必須有與 fallback 關聯的 DispatchKey;例如,僅從具有命名空間
_
的 TORCH_LIBRARY_IMPL() 呼叫此函數。// Example: TORCH_LIBRARY_IMPL(_, AutogradXLA, m) { // If there is not a kernel explicitly registered // for AutogradXLA, fallthrough to the next // available kernel m.fallback(torch::CppFunction::makeFallthrough()); } // See aten/src/ATen/core/dispatch/backend_fallback_test.cpp // for a full example of boxed fallback
- 參數
raw_f – 實作 fallback 的函數。Unboxed 函數通常無法用作 fallback 函數,因為 fallback 函數必須適用於每個運算子 (即使它們具有不同的類型簽名)。典型的引數是 CppFunction::makeFallthrough() 或 CppFunction::makeFromBoxedFunction()
-
template<typename Schema>
-
class CppFunction
表示實作運算子的 C++ 函數。
大多數使用者不會直接與此類別互動,除非透過錯誤訊息:此函數定義的建構函式定義了您可以透過介面綁定的可允許「函數」類型的集合。
此類別會消除傳入函數的類型,但會透過函數的推斷 schema 持久記錄該類型。
Public Functions
-
template<typename Func>
inline explicit CppFunction(Func *f, std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t> = nullptr) 此重載接受函式指標,例如
CppFunction(&add_impl)
-
template<typename FuncPtr>
inline explicit CppFunction(FuncPtr f, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr) 此重載接受編譯時期的函式指標,例如
CppFunction(TORCH_FN(add_impl))
-
template<typename Lambda>
inline explicit CppFunction(Lambda &&f, std::enable_if_t<c10::guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr) 此重載接受 lambda,例如
CppFunction([](const Tensor& self) { ...
})
公共靜態函式
-
static inline CppFunction makeFallthrough()
此函式會建立一個 fallthrough 函式。
Fallthrough 函式會立即重新分派到下一個可用的分派金鑰,但其實作效率比以相同方式手動編寫的函式更高。
-
template<c10::BoxedKernel::BoxedKernelFunction *func>
static inline CppFunction makeFromBoxedFunction() 從具有簽章
void(const OperatorHandle&, Stack*)
的 boxed 核心函式建立函式;也就是說,它們會以 boxed 呼叫慣例接收引數堆疊,而不是以原生 C++ 呼叫慣例接收。Boxed 函式通常僅用於透過 torch::Library::fallback() 註冊後端回退。
-
template<class KernelFunctor>
static inline CppFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) 從 boxed kernel functor 建立一個函式,該 functor 定義了
operator()(const OperatorHandle&, DispatchKeySet, Stack*)
(從 boxed 呼叫慣例接收引數) 並繼承自c10::OperatorKernel
。與 makeFromBoxedFunction 不同,以這種方式註冊的函式還可以攜帶由 functor 管理的額外狀態;如果您正在編寫一個到其他實現(例如,一個 Python callable)的適配器,該適配器與已註冊的 kernel 動態關聯,則此功能非常有用。
-
template<typename FuncPtr, std::enable_if_t<c10::guts::is_function_type<FuncPtr>::value, std::nullptr_t> = nullptr>
static inline CppFunction makeFromUnboxedFunction(FuncPtr *f) 從一個 unboxed kernel 函式建立一個函式。
這通常用於註冊常見的運算符。
-
template<typename FuncPtr, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr>
static inline CppFunction makeFromUnboxedFunction(FuncPtr f) 從編譯時期的 unboxed kernel 函式指標建立一個函式。
這通常用於註冊常見的運算符。編譯時間函式指標可用於允許編譯器優化(例如,內聯)對它的呼叫。
-
template<typename Func>
函式¶
-
template<typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func &&raw_f)¶ 建立一個與特定調度金鑰相關聯的 torch::CppFunction。
除非調度器確定這個特定的 c10::DispatchKey 是應該調度到的金鑰,否則不會調用帶有 c10::DispatchKey 標記的 torch::CppFunctions。
通常不直接使用此函式,而是建議使用 TORCH_LIBRARY_IMPL(),它會隱式地為其主體內的所有註冊呼叫設定 c10::DispatchKey。
-
template<typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func &&raw_f)¶ 方便的 dispatch() 多載,接受 c10::DeviceType。
-
inline c10::FunctionSchema schema(const char *str, c10::AliasAnalysisKind k, bool allow_typevars = false)¶
從字串建構 c10::FunctionSchema,並明確指定 c10::AliasAnalysisKind。
通常,schema 只是以字串形式傳入,但如果您需要指定自定義的別名分析,則可以將字串替換為對此函式的呼叫。
// Default alias analysis (FROM_SCHEMA) m.def("def3(Tensor self) -> Tensor"); // Pure function alias analysis m.def(torch::schema("def3(Tensor self) -> Tensor", c10::AliasAnalysisKind::PURE_FUNCTION));
-
inline c10::FunctionSchema schema(const char *s, bool allow_typevars = false)¶
函數 schema 可以直接從字串文字建構。