自定義 C++ 和 CUDA 擴展¶
建立於:2018 年 4 月 26 日 | 最後更新:2024 年 7 月 22 日 | 最後驗證:2024 年 11 月 05 日
警告
本教學已在 PyTorch 2.4 中棄用。請參閱 PyTorch 自定義運算符,以取得關於使用自定義 C++/CUDA 擴展擴展 PyTorch 的最新指南。
PyTorch 提供了大量與神經網路、任意張量代數、資料處理和其他用途相關的運算。但是,您可能仍然需要更自定義的運算。例如,您可能想要使用您在論文中找到的一種新型啟動函數,或者實作您作為研究一部分開發的運算。
在 PyTorch 中整合此類自定義運算的最簡單方法是透過擴展 Function
和 Module
以 Python 撰寫它,如此處所述。這為您提供了自動微分的全部力量(使您免於編寫導數函數)以及 Python 的常用表達能力。但是,有時您的運算最好在 C++ 中實作。例如,您的程式碼可能需要真的很快,因為它在您的模型中被調用得非常頻繁,或者即使對於少數調用來說也非常昂貴。另一個可能的原因是它依賴於其他 C 或 C++ 函式庫或與之互動。為了應對這種情況,PyTorch 提供了一種非常簡單的方式來撰寫自定義的C++ 擴展。
C++ 擴展是我們開發的一種機制,允許使用者(您)建立定義在源程式碼之外的 PyTorch 運算符,也就是說,與 PyTorch 後端分離。這種方法與原生 PyTorch 運算的實作方式不同。C++ 擴展旨在讓您免於將運算與 PyTorch 後端整合相關的許多樣板程式碼,同時為您的基於 PyTorch 的專案提供高度的靈活性。儘管如此,一旦您將您的運算定義為 C++ 擴展,將其轉換為原生 PyTorch 函數主要是一個程式碼組織問題,如果您決定向上游貢獻您的運算,您可以在事後解決這個問題。
動機和範例¶
本節的其餘部分將逐步說明撰寫和使用 C++(和 CUDA)擴展的實際範例。如果您被追趕,或者如果到今天結束您還沒有完成該運算,有人會解僱您,您可以跳過本節,直接前往下一節中的實作細節。
假設您想出了一種新型的遞迴單元,您發現它比現有技術具有卓越的屬性。這種遞迴單元與 LSTM 相似,但不同之處在於它缺乏遺忘閘,並使用指數線性單元 (ELU) 作為其內部啟動函數。因為這個單元永遠不會忘記,我們將稱之為 LLTM,或 Long-Long-Term-Memory 單元。
LLTM 與普通 LSTM 的兩個不同之處非常重要,我們無法為我們的目的配置 PyTorch 的 LSTMCell
,因此我們必須建立一個自定義單元。為此的第一個也是最簡單的方法 - 並且可能在所有情況下都是一個好的第一步 - 是使用 Python 中的普通 PyTorch 實作我們所需的功能。為此,我們需要子類化 torch.nn.Module
並實作 LLTM 的正向傳遞。這看起來像這樣
class LLTM(torch.nn.Module):
def __init__(self, input_features, state_size):
super(LLTM, self).__init__()
self.input_features = input_features
self.state_size = state_size
# 3 * state_size for input gate, output gate and candidate cell gate.
# input_features + state_size because we will multiply with [input, h].
self.weights = torch.nn.Parameter(
torch.empty(3 * state_size, input_features + state_size))
self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.state_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, +stdv)
def forward(self, input, state):
old_h, old_cell = state
X = torch.cat([old_h, input], dim=1)
# Compute the input, output and candidate cell gates with one MM.
gate_weights = F.linear(X, self.weights, self.bias)
# Split the combined gate weight matrix into its components.
gates = gate_weights.chunk(3, dim=1)
input_gate = torch.sigmoid(gates[0])
output_gate = torch.sigmoid(gates[1])
# Here we use an ELU instead of the usual tanh.
candidate_cell = F.elu(gates[2])
# Compute the new cell state.
new_cell = old_cell + candidate_cell * input_gate
# Compute the new hidden state and output.
new_h = torch.tanh(new_cell) * output_gate
return new_h, new_cell
然後我們可以像預期那樣使用它
import torch
X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)
rnn = LLTM(input_features, state_size)
new_h, new_C = rnn(X, (h, C))
當然,如果可能且合理,您應該使用這種方法來擴展 PyTorch。由於 PyTorch 對 CPU 和 GPU 上的運算都有高度優化的實現,並由諸如 NVIDIA cuDNN、Intel MKL 或 NNPACK 等函式庫提供支援,因此像上面這樣的 PyTorch 程式碼通常速度夠快。然而,我們也可以看出,在某些情況下,還有進一步提升效能的空間。最明顯的原因是 PyTorch 不了解您正在實作的演算法。它只知道您用來組成演算法的個別運算。因此,PyTorch 必須個別執行您的運算,一個接一個。由於每次單獨調用運算的實現(或核心),這可能涉及啟動 CUDA 核心,都會有一定的額外負擔,因此這種額外負擔在許多函數調用中可能會變得非常顯著。此外,執行我們程式碼的 Python 直譯器本身也會減慢我們程式的速度。
因此,加速的明確方法是用 C++(或 CUDA)重寫部分程式碼,並融合特定的運算群組。融合意味著將許多函數的實現合併到一個函數中,這可以從更少的核心啟動以及其他我們可以利用全局資料流的更高可見性執行的優化中獲益。
讓我們看看如何使用 C++ 擴展來實現 LLTM 的融合版本。 我們將首先用純 C++ 編寫它,使用 ATen 函式庫,該函式庫為 PyTorch 的許多後端提供支援,並了解它讓我們多麼輕鬆地轉換我們的 Python 程式碼。然後,我們將通過將模型的某些部分移動到 CUDA 核心來進一步加速,從而受益於 GPU 提供的大規模並行性。
編寫 C++ 擴展¶
C++ 擴展有兩種形式:它們可以使用 setuptools
進行“提前”建置,或通過 torch.utils.cpp_extension.load()
進行“即時”建置。我們將從第一種方法開始,稍後再討論後者。
使用 setuptools
建置¶
對於“提前”形式,我們通過編寫一個 setup.py
腳本來建置我們的 C++ 擴展,該腳本使用 setuptools 來編譯我們的 C++ 程式碼。對於 LLTM,它看起來像這樣簡單
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name='lltm_cpp',
ext_modules=[cpp_extension.CppExtension('lltm_cpp', ['lltm.cpp'])],
cmdclass={'build_ext': cpp_extension.BuildExtension})
在此程式碼中,CppExtension
是 setuptools.Extension
的一個方便的封裝器,它傳遞正確的包含路徑並將擴展的語言設定為 C++。等效的原始 setuptools
程式碼將僅僅是
Extension(
name='lltm_cpp',
sources=['lltm.cpp'],
include_dirs=cpp_extension.include_paths(),
language='c++')
BuildExtension
執行許多必需的配置步驟和檢查,並且在混合 C++/CUDA 擴展的情況下還管理混合編譯。這就是我們現在真正需要了解的關於建置 C++ 擴展的所有內容!現在讓我們看看我們的 C++ 擴展的實現,它進入 lltm.cpp
。
編寫 C++ Op¶
讓我們開始在 C++ 中實現 LLTM!我們在反向傳遞中需要的一個函數是 sigmoid 的導數。這是一小段程式碼,可以討論在編寫 C++ 擴展時可供我們使用的總體環境
#include <torch/extension.h>
#include <iostream>
torch::Tensor d_sigmoid(torch::Tensor z) {
auto s = torch::sigmoid(z);
return (1 - s) * s;
}
<torch/extension.h>
是一個一站式標頭,用於包含所有必要的 PyTorch 位元,以編寫 C++ 擴展。它包括
ATen 函式庫,它是我們進行張量計算的主要 API,
pybind11,這是我們如何為 C++ 程式碼建立 Python 綁定的方式,
管理 ATen 和 pybind11 之間交互細節的標頭。
d_sigmoid()
的實現展示了如何使用 ATen API。 PyTorch 的張量和變數介面是從 ATen 函式庫自動生成的,因此我們可以或多或少地將我們的 Python 實現 1:1 轉換為 C++。我們所有計算的主要資料類型將是 torch::Tensor
。可以在這裡檢查其完整的 API。另請注意,我們可以包含 <iostream>
或任何其他 C 或 C++ 標頭 – 我們擁有 C++11 的全部能力。
請注意,CUDA-11.5 nvcc 在 Windows 上解析 torch/extension.h 時會遇到內部編譯器錯誤。要解決此問題,請將 python 綁定邏輯移至純 C++ 檔案。範例用法
#include <ATen/ATen.h>
at::Tensor SigmoidAlphaBlendForwardCuda(....)
取代為
#include <torch/extension.h>
torch::Tensor SigmoidAlphaBlendForwardCuda(...)
目前 nvcc 錯誤的公開問題 這裡。完整的解決方法程式碼範例 這裡。
正向傳遞¶
接下來,我們可以將整個正向傳遞移植到 C++
#include <vector>
std::vector<at::Tensor> lltm_forward(
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell) {
auto X = torch::cat({old_h, input}, /*dim=*/1);
auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
auto gates = gate_weights.chunk(3, /*dim=*/1);
auto input_gate = torch::sigmoid(gates[0]);
auto output_gate = torch::sigmoid(gates[1]);
auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
auto new_cell = old_cell + candidate_cell * input_gate;
auto new_h = torch::tanh(new_cell) * output_gate;
return {new_h,
new_cell,
input_gate,
output_gate,
candidate_cell,
X,
gate_weights};
}
反向傳遞¶
C++ 擴展 API 目前沒有提供自動為我們產生反向函數的方法。因此,我們還必須實現 LLTM 的反向傳遞,該反向傳遞計算損失相對於正向傳遞的每個輸入的導數。最終,我們將把正向和反向函數都放入 torch.autograd.Function
中,以建立一個漂亮的 Python 綁定。反向函數稍微複雜一些,因此我們不會深入研究程式碼(如果您有興趣,Alex Graves 的論文是一個很好的讀物,可以獲取更多資訊)
// tanh'(z) = 1 - tanh^2(z)
torch::Tensor d_tanh(torch::Tensor z) {
return 1 - z.tanh().pow(2);
}
// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
auto e = z.exp();
auto mask = (alpha * (e - 1)) < 0;
return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}
std::vector<torch::Tensor> lltm_backward(
torch::Tensor grad_h,
torch::Tensor grad_cell,
torch::Tensor new_cell,
torch::Tensor input_gate,
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gate_weights,
torch::Tensor weights) {
auto d_output_gate = torch::tanh(new_cell) * grad_h;
auto d_tanh_new_cell = output_gate * grad_h;
auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
auto d_old_cell = d_new_cell;
auto d_candidate_cell = input_gate * d_new_cell;
auto d_input_gate = candidate_cell * d_new_cell;
auto gates = gate_weights.chunk(3, /*dim=*/1);
d_input_gate *= d_sigmoid(gates[0]);
d_output_gate *= d_sigmoid(gates[1]);
d_candidate_cell *= d_elu(gates[2]);
auto d_gates =
torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
auto d_weights = d_gates.t().mm(X);
auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
auto d_X = d_gates.mm(weights);
const auto state_size = grad_h.size(1);
auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
auto d_input = d_X.slice(/*dim=*/1, state_size);
return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
綁定到 Python¶
一旦你用 C++ 和 ATen 編寫好你的操作,你就可以使用 pybind11 以非常簡單的方式將你的 C++ 函式或類別綁定到 Python 中。關於 PyTorch C++ 擴展的這部分,你所遇到的問題或疑慮,很大程度上可以透過 pybind11 文件 獲得解答。
對於我們的擴展,必要的綁定程式碼僅需四行
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &lltm_forward, "LLTM forward");
m.def("backward", &lltm_backward, "LLTM backward");
}
這裡需要注意的是巨集 TORCH_EXTENSION_NAME
。torch 擴展構建會將它定義為你在 setup.py
腳本中給予你的擴展名稱。在這種情況下,TORCH_EXTENSION_NAME
的值將會是 "lltm_cpp"。 這樣做是為了避免在兩個地方(構建腳本和你的 C++ 程式碼)維護擴展的名稱,因為兩者之間的不匹配可能會導致令人討厭且難以追蹤的問題。
使用你的擴展¶
我們現在可以開始在 PyTorch 中導入我們的擴展。此時,你的目錄結構可能如下所示
pytorch/
lltm-extension/
lltm.cpp
setup.py
現在,執行 python setup.py install
來構建和安裝你的擴展。 它看起來應該像這樣
running install
running bdist_egg
running egg_info
creating lltm_cpp.egg-info
writing lltm_cpp.egg-info/PKG-INFO
writing dependency_links to lltm_cpp.egg-info/dependency_links.txt
writing top-level names to lltm_cpp.egg-info/top_level.txt
writing manifest file 'lltm_cpp.egg-info/SOURCES.txt'
reading manifest file 'lltm_cpp.egg-info/SOURCES.txt'
writing manifest file 'lltm_cpp.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'lltm_cpp' extension
creating build
creating build/temp.linux-x86_64-3.7
gcc -pthread -B ~/local/miniconda/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I~/local/miniconda/lib/python3.7/site-packages/torch/include -I~/local/miniconda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I~/local/miniconda/lib/python3.7/site-packages/torch/include/TH -I~/local/miniconda/lib/python3.7/site-packages/torch/include/THC -I~/local/miniconda/include/python3.7m -c lltm.cpp -o build/temp.linux-x86_64-3.7/lltm.o -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=lltm_cpp -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++11
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
creating build/lib.linux-x86_64-3.7
g++ -pthread -shared -B ~/local/miniconda/compiler_compat -L~/local/miniconda/lib -Wl,-rpath=~/local/miniconda/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.7/lltm.o -o build/lib.linux-x86_64-3.7/lltm_cpp.cpython-37m-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.7/lltm_cpp.cpython-37m-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for lltm_cpp.cpython-37m-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/lltm_cpp.py to lltm_cpp.cpython-37.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.lltm_cpp.cpython-37: module references __file__
creating 'dist/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing lltm_cpp-0.0.0-py3.7-linux-x86_64.egg
removing '~/local/miniconda/lib/python3.7/site-packages/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg' (and everything under it)
creating ~/local/miniconda/lib/python3.7/site-packages/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg
Extracting lltm_cpp-0.0.0-py3.7-linux-x86_64.egg to ~/local/miniconda/lib/python3.7/site-packages
lltm-cpp 0.0.0 is already the active version in easy-install.pth
Installed ~/local/miniconda/lib/python3.7/site-packages/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg
Processing dependencies for lltm-cpp==0.0.0
Finished processing dependencies for lltm-cpp==0.0.0
關於編譯器的簡短說明:由於 ABI 版本問題,你用來構建 C++ 擴展的編譯器必須與 PyTorch 构建時使用的編譯器 ABI 相容。 實際上,這意味著你必須在 Linux 上使用 GCC 4.9 及以上版本。 對於 Ubuntu 16.04 和其他較新的 Linux 發行版,這應該已經是預設的編譯器。 在 MacOS 上,你必須使用 clang(它沒有任何 ABI 版本問題)。 在最壞的情況下,你可以使用你的編譯器從原始碼構建 PyTorch,然後使用相同的編譯器構建擴展。
一旦你的擴展被構建,你可以簡單地在 Python 中導入它,使用你在 setup.py
腳本中指定的名稱。 請務必先 import torch
,因為這會解析一些動態連結器必須看到的符號
In [1]: import torch
In [2]: import lltm_cpp
In [3]: lltm_cpp.forward
Out[3]: <function lltm.PyCapsule.forward>
如果我們在函數或模組上呼叫 help()
,我們可以發現它的簽名與我們的 C++ 程式碼相符
In[4] help(lltm_cpp.forward)
forward(...) method of builtins.PyCapsule instance
forward(arg0: torch::Tensor, arg1: torch::Tensor, arg2: torch::Tensor, arg3: torch::Tensor, arg4: torch::Tensor) -> List[torch::Tensor]
LLTM forward
既然我們現在可以從 Python 呼叫我們的 C++ 函數,我們可以透過 torch.autograd.Function
和 torch.nn.Module
來包裝它們,使它們成為 PyTorch 的一等公民
import math
import torch
# Our module!
import lltm_cpp
class LLTMFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weights, bias, old_h, old_cell):
outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell)
new_h, new_cell = outputs[:2]
variables = outputs[1:] + [weights]
ctx.save_for_backward(*variables)
return new_h, new_cell
@staticmethod
def backward(ctx, grad_h, grad_cell):
outputs = lltm_cpp.backward(
grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
return d_input, d_weights, d_bias, d_old_h, d_old_cell
class LLTM(torch.nn.Module):
def __init__(self, input_features, state_size):
super(LLTM, self).__init__()
self.input_features = input_features
self.state_size = state_size
self.weights = torch.nn.Parameter(
torch.empty(3 * state_size, input_features + state_size))
self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.state_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, +stdv)
def forward(self, input, state):
return LLTMFunction.apply(input, self.weights, self.bias, *state)
效能比較¶
既然我們能夠從 PyTorch 使用和呼叫我們的 C++ 程式碼,我們可以運行一個小型基準測試,看看我們從用 C++ 重寫我們的 op 中獲得了多少效能提升。 我們將多次運行 LLTM 的前向和後向,並測量持續時間
import time
import torch
batch_size = 16
input_features = 32
state_size = 128
X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)
rnn = LLTM(input_features, state_size)
forward = 0
backward = 0
for _ in range(100000):
start = time.time()
new_h, new_C = rnn(X, (h, C))
forward += time.time() - start
start = time.time()
(new_h.sum() + new_C.sum()).backward()
backward += time.time() - start
print('Forward: {:.3f} s | Backward {:.3f} s'.format(forward, backward))
如果我們使用本文開頭以純 Python 編寫的原始 LLTM 運行此程式碼,我們會得到以下數字(在我的機器上)
Forward: 506.480 us | Backward 444.694 us
以及我們新的 C++ 版本
Forward: 349.335 us | Backward 443.523 us
我們已經可以看到前向函數的顯著加速(超過 30%)。 對於後向函數,可以看到加速,儘管不是主要的加速。 我上面寫的後向傳遞並未經過特別的優化,並且肯定可以改進。 此外,PyTorch 的自動微分引擎可以自動並行化計算圖,可以使用更有效率的整體操作流程,並且也是用 C++ 實現的,因此預計會很快。 儘管如此,這是一個好的開始。
在 GPU 設備上的效能¶
關於 PyTorch 的 ATen 後端的一個奇妙的事實是,它抽象了你正在運行的計算設備。 這意味著我們為 CPU 編寫的相同程式碼 也 可以運行在 GPU 上,並且個別操作將相應地分派到 GPU 優化的實現。 對於某些操作,例如矩陣乘法(如 mm
或 addmm
),這是一個很大的優勢。 讓我們看看透過使用 CUDA 張量運行我們的 C++ 程式碼,我們獲得了多少效能提升。 不需要對我們的實作進行任何更改,我們只需要從 Python 將我們的張量放入 GPU 記憶體中,方法是在建立時添加 device=cuda_device
參數,或在建立後使用 .to(cuda_device)
import torch
assert torch.cuda.is_available()
cuda_device = torch.device("cuda") # device object representing GPU
batch_size = 16
input_features = 32
state_size = 128
# Note the device=cuda_device arguments here
X = torch.randn(batch_size, input_features, device=cuda_device)
h = torch.randn(batch_size, state_size, device=cuda_device)
C = torch.randn(batch_size, state_size, device=cuda_device)
rnn = LLTM(input_features, state_size).to(cuda_device)
forward = 0
backward = 0
for _ in range(100000):
start = time.time()
new_h, new_C = rnn(X, (h, C))
torch.cuda.synchronize()
forward += time.time() - start
start = time.time()
(new_h.sum() + new_C.sum()).backward()
torch.cuda.synchronize()
backward += time.time() - start
print('Forward: {:.3f} us | Backward {:.3f} us'.format(forward * 1e6/1e5, backward * 1e6/1e5))
再次比較我們的普通 PyTorch 程式碼和我們的 C++ 版本,現在兩者都在 CUDA 設備上運行,我們再次看到了效能提升。 對於 Python/PyTorch
Forward: 187.719 us | Backward 410.815 us
和 C++/ATen
Forward: 149.802 us | Backward 393.458 us
與非 CUDA 程式碼相比,這是一個非常好的整體加速。 但是,我們可以透過編寫自定義 CUDA 核心從我們的 C++ 程式碼中提取更多的效能,我們將很快深入研究。 在此之前,讓我們討論另一種構建 C++ 擴展的方式。
JIT 編譯擴展¶
之前,我提到有兩種構建 C++ 擴展的方式:使用 setuptools
或即時 (JIT)。 在介紹了前者之後,讓我們詳細說明後者。 JIT 編譯機制為你提供了一種透過呼叫 PyTorch API 中的一個簡單函數 torch.utils.cpp_extension.load()
,來動態編譯和載入你的擴展的方法。 對於 LLTM,這將看起來像這樣簡單
from torch.utils.cpp_extension import load
lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"])
在這裡,我們向函數提供與 setuptools
相同的信息。 在後台,這將執行以下操作
建立一個臨時目錄
/tmp/torch_extensions/lltm
,將 Ninja 建置文件發送到該臨時目錄中,
將你的原始碼編譯到共享庫中,
將此共享庫導入為 Python 模組。
事實上,如果您傳遞 verbose=True
給 cpp_extension.load()
,您將會收到關於處理過程的通知。
Using /tmp/torch_extensions as PyTorch extensions root...
Emitting ninja build file /tmp/torch_extensions/lltm_cpp/build.ninja...
Building extension module lltm_cpp...
Loading extension module lltm_cpp...
產生的 Python 模組將與 setuptools 產生的一模一樣,但消除了維護單獨的 setup.py
建置檔案的需求。 如果您的設定更複雜,並且您確實需要 setuptools
的完整功能,您可以編寫自己的 setup.py
– 但在許多情況下,這種 JIT 技術就足夠好了。 第一次執行這行程式碼時,由於擴充功能在背景編譯,因此需要一些時間。 由於我們使用 Ninja 建置系統來建置您的原始碼,因此重新編譯是增量的,因此當您第二次執行您的 Python 模組時,重新載入擴充功能的速度很快,並且如果您沒有更改擴充功能的原始檔,則開銷很低。
編寫混合 C++/CUDA 擴充功能¶
為了真正將我們的實作提升到一個新的水平,我們可以手寫一部分的前向和反向傳遞,並使用自定義 CUDA 核心。 對於 LLTM,這有希望特別有效,因為有一系列大量的 pointwise 運算,這些運算都可以在單個 CUDA 核心中融合和並行化。 讓我們看看如何編寫這樣的 CUDA 核心,並使用此擴充機制將其與 PyTorch 整合。
編寫 CUDA 擴充功能的一般策略是首先編寫一個 C++ 檔案,該檔案定義將從 Python 呼叫的函數,並使用 pybind11 將這些函數綁定到 Python。 此外,此檔案還將宣告在 CUDA (.cu
) 檔案中定義的函數。 然後,C++ 函數將進行一些檢查,並最終將其呼叫轉發到 CUDA 函數。 在 CUDA 檔案中,我們編寫實際的 CUDA 核心。cpp_extension
套件將負責使用 C++ 編譯器(如 gcc
)編譯 C++ 原始碼,並使用 NVIDIA 的 nvcc
編譯器編譯 CUDA 原始碼。 這確保了每個編譯器都能處理它最了解要編譯的檔案。 最終,它們將被連結到一個共享程式庫中,我們可以從 Python 程式碼中使用它。
我們將從 C++ 檔案開始,例如,我們將其命名為 lltm_cuda.cpp
。
#include <torch/extension.h>
#include <vector>
// CUDA forward declarations
std::vector<torch::Tensor> lltm_cuda_forward(
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell);
std::vector<torch::Tensor> lltm_cuda_backward(
torch::Tensor grad_h,
torch::Tensor grad_cell,
torch::Tensor new_cell,
torch::Tensor input_gate,
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gate_weights,
torch::Tensor weights);
// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> lltm_forward(
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell) {
CHECK_INPUT(input);
CHECK_INPUT(weights);
CHECK_INPUT(bias);
CHECK_INPUT(old_h);
CHECK_INPUT(old_cell);
return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
}
std::vector<torch::Tensor> lltm_backward(
torch::Tensor grad_h,
torch::Tensor grad_cell,
torch::Tensor new_cell,
torch::Tensor input_gate,
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gate_weights,
torch::Tensor weights) {
CHECK_INPUT(grad_h);
CHECK_INPUT(grad_cell);
CHECK_INPUT(input_gate);
CHECK_INPUT(output_gate);
CHECK_INPUT(candidate_cell);
CHECK_INPUT(X);
CHECK_INPUT(gate_weights);
CHECK_INPUT(weights);
return lltm_cuda_backward(
grad_h,
grad_cell,
new_cell,
input_gate,
output_gate,
candidate_cell,
X,
gate_weights,
weights);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &lltm_forward, "LLTM forward (CUDA)");
m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
}
正如您所看到的,它在很大程度上是樣板程式碼、檢查和轉發到我們將在 CUDA 檔案中定義的函數。 我們將此檔案命名為 lltm_cuda_kernel.cu
(注意 .cu
擴展名!)。 NVCC 可以合理地編譯 C++11,因此我們仍然可以使用 ATen 和 C++ 標準程式庫(但不能使用 torch.h
)。 請注意,setuptools
無法處理具有相同名稱但不同擴展名的檔案,因此如果您使用 setup.py
方法而不是 JIT 方法,您必須為您的 CUDA 檔案提供與您的 C++ 檔案不同的名稱(對於 JIT 方法,lltm.cpp
和 lltm.cu
都可以正常工作)。 讓我們簡單地看一下這個檔案會是什麼樣子
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
template <typename scalar_t>
__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
return 1.0 / (1.0 + exp(-z));
}
在這裡,我們看到了我剛才描述的標頭,以及我們正在使用 CUDA 特定的宣告,例如 __device__
和 __forceinline__
以及函數,例如 exp
。 讓我們繼續介紹我們需要的更多輔助函數
template <typename scalar_t>
__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
const auto s = sigmoid(z);
return (1.0 - s) * s;
}
template <typename scalar_t>
__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
const auto t = tanh(z);
return 1 - (t * t);
}
template <typename scalar_t>
__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
return fmax(0.0, z) + fmin(0.0, alpha * (exp(z) - 1.0));
}
template <typename scalar_t>
__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
const auto e = exp(z);
const auto d_relu = z < 0.0 ? 0.0 : 1.0;
return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
}
現在要實際實現一個函數,我們需要兩件事:一個執行我們不想手動編寫並呼叫 CUDA 核心的運算的函數,然後是我們想要加速的部分的實際 CUDA 核心。 對於前向傳遞,第一個函數應該如下所示
std::vector<torch::Tensor> lltm_cuda_forward(
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell) {
auto X = torch::cat({old_h, input}, /*dim=*/1);
auto gates = torch::addmm(bias, X, weights.transpose(0, 1));
const auto batch_size = old_cell.size(0);
const auto state_size = old_cell.size(1);
auto new_h = torch::zeros_like(old_cell);
auto new_cell = torch::zeros_like(old_cell);
auto input_gate = torch::zeros_like(old_cell);
auto output_gate = torch::zeros_like(old_cell);
auto candidate_cell = torch::zeros_like(old_cell);
const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);
AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
gates.data<scalar_t>(),
old_cell.data<scalar_t>(),
new_h.data<scalar_t>(),
new_cell.data<scalar_t>(),
input_gate.data<scalar_t>(),
output_gate.data<scalar_t>(),
candidate_cell.data<scalar_t>(),
state_size);
}));
return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
}
這裡的主要興趣點是 AT_DISPATCH_FLOATING_TYPES
巨集和核心啟動(由 <<<...>>>
指示)。 雖然 ATen 抽象了我們處理的張量的設備和資料類型,但張量在運行時仍然會由具體設備上的具體類型的記憶體支持。 因此,我們需要一種方法來在運行時確定張量的類型,然後有選擇地呼叫具有相應正確類型簽名的函數。 手動完成的話,這(概念上)看起來像這樣
switch (tensor.type().scalarType()) {
case torch::ScalarType::Double:
return function<double>(tensor.data<double>());
case torch::ScalarType::Float:
return function<float>(tensor.data<float>());
...
}
AT_DISPATCH_FLOATING_TYPES
的目的是為我們處理這種分派。 它採用一個類型(在我們的例子中為 gates.type()
)、一個名稱(用於錯誤訊息)和一個 lambda 函數。 在這個 lambda 函數中,類型別名 scalar_t
是可用的,並且定義為張量在該上下文中實際運行時的類型。 因此,如果我們有一個模板函數(我們的 CUDA 核心將是),我們可以實例化它,並使用這個 scalar_t
別名,並且將呼叫正確的函數。 在這種情況下,我們也希望將張量的資料指標作為 scalar_t
類型的指標檢索。 如果您想分派所有類型,而不僅僅是浮點類型(Float
和 Double
),您可以使用 AT_DISPATCH_ALL_TYPES
。
請注意,我們使用普通的 ATen 執行一些運算。 這些運算仍然會在 GPU 上運行,但使用 ATen 的預設實作。 這是合理的,因為 ATen 將對矩陣乘法(例如 addmm
)或卷積使用高度優化的例程,這些例程將更難以實作和改進。
對於核心啟動本身,我們在這裡指定每個 CUDA 區塊將有 1024 個線程,並且整個 GPU 網格被分成多個 1 x 1024
線程的區塊,以滿足我們的需求,即使用一個線程填充每個元件的矩陣。 例如,如果我們的狀態大小為 2048,並且我們的批次大小為 4,我們將總共啟動 4 x 2 = 8
個區塊,每個區塊有 1024 個線程。 如果您以前從未聽說過 CUDA「區塊」或「網格」,那麼關於 CUDA 的入門讀物可能會有所幫助。
實際的 CUDA 核心非常簡單(如果您以前編寫過 GPU 程式)
template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
const scalar_t* __restrict__ gates,
const scalar_t* __restrict__ old_cell,
scalar_t* __restrict__ new_h,
scalar_t* __restrict__ new_cell,
scalar_t* __restrict__ input_gate,
scalar_t* __restrict__ output_gate,
scalar_t* __restrict__ candidate_cell,
size_t state_size) {
const int column = blockIdx.x * blockDim.x + threadIdx.x;
const int index = blockIdx.y * state_size + column;
const int gates_row = blockIdx.y * (state_size * 3);
if (column < state_size) {
input_gate[index] = sigmoid(gates[gates_row + column]);
output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
new_cell[index] =
old_cell[index] + candidate_cell[index] * input_gate[index];
new_h[index] = tanh(new_cell[index]) * output_gate[index];
}
}
這裡最有趣的是,我們能夠完全並行地計算閘極矩陣中每個單獨元件的所有這些 pointwise 運算。 如果您想像必須使用一個巨大的 for
迴圈在序列中對一百萬個元素執行此操作,您就會明白為什麼這會更快得多。
使用訪問器¶
您可以在 CUDA 核心 (kernel) 中看到我們直接使用具有正確類型的指標。實際上,在 CUDA 核心中使用高階、類型不可知的張量會非常沒有效率。
然而,這會犧牲易用性和可讀性,尤其是在處理高維度資料時。在我們的範例中,我們知道連續的 gates
張量有 3 個維度
batch,大小為
batch_size
,步幅為3*state_size
row,大小為
3
,步幅為state_size
index,大小為
state_size
,步幅為1
那麼,我們如何在核心中存取元素 gates[n][row][column]
呢? 事實證明,您需要步幅才能透過一些簡單的算術運算來存取元素。
gates.data<scalar_t>()[n*3*state_size + row*state_size + column]
除了冗長之外,這個表達式還需要顯式地知道步幅,因此必須在核心函數的參數中傳遞。 您可以看到,如果核心函數接受多個具有不同大小的張量,您最終會得到一個非常長的參數列表。
幸運的是,ATen 提供了存取器 (accessors),這些存取器只需進行一次動態檢查,以驗證 Tensor 的類型和維度數量。 然後,存取器公開了一個 API,用於有效率地存取 Tensor 元素,而無需轉換為單個指標。
torch::Tensor foo = torch::rand({12, 12});
// assert foo is 2-dimensional and holds floats.
auto foo_a = foo.accessor<float,2>();
float trace = 0;
for(int i = 0; i < foo_a.size(0); i++) {
// use the accessor foo_a to get tensor data.
trace += foo_a[i][i];
}
存取器物件具有相對高階的介面,包括 .size()
和 .stride()
方法以及多維索引。 .accessor<>
介面設計用於有效率地存取 CPU 張量上的資料。 CUDA 張量的等效介面是 packed_accessor64<>
和 packed_accessor32<>
,它們產生具有 64 位元或 32 位元整數索引的 Packed Accessors。
Accessor 的根本區別在於,Packed Accessor 將大小和步幅資料複製到其結構中,而不是指向它。 這允許我們將它傳遞給 CUDA 核心函數,並在其中使用它的介面。
我們可以設計一個接受 Packed Accessors 而不是指標的函數。
__global__ void lltm_cuda_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell)
讓我們分解這裡使用的模板。 前兩個參數 scalar_t
和 2
與常規 Accessor 相同。 參數 torch::RestrictPtrTraits
表示必須使用 __restrict__
關鍵字。 另請注意,我們使用了 PackedAccessor32
變體,它將大小和步幅儲存在 int32_t
中。 這一點很重要,因為使用 64 位元變體 (PackedAccessor64
) 會使核心變慢。
函數宣告變成
template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
//batch index
const int n = blockIdx.y;
// column index
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < gates.size(2)){
input_gate[n][c] = sigmoid(gates[n][0][c]);
output_gate[n][c] = sigmoid(gates[n][1][c]);
candidate_cell[n][c] = elu(gates[n][2][c]);
new_cell[n][c] =
old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
}
}
實現方式更具可讀性! 然後透過使用主機函數中的 .packed_accessor32<>
方法建立 Packed Accessors 來呼叫此函數。
std::vector<torch::Tensor> lltm_cuda_forward(
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell) {
auto X = torch::cat({old_h, input}, /*dim=*/1);
auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
const auto batch_size = old_cell.size(0);
const auto state_size = old_cell.size(1);
auto gates = gate_weights.reshape({batch_size, 3, state_size});
auto new_h = torch::zeros_like(old_cell);
auto new_cell = torch::zeros_like(old_cell);
auto input_gate = torch::zeros_like(old_cell);
auto output_gate = torch::zeros_like(old_cell);
auto candidate_cell = torch::zeros_like(old_cell);
const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);
AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
new_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
}));
return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
}
反向傳遞 (backwards pass) 遵循大致相同的模式,我不會進一步詳細說明。
template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> d_gates,
const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_h,
const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_cell,
const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
//batch index
const int n = blockIdx.y;
// column index
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < d_gates.size(2)){
const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
const auto d_new_cell =
d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];
d_old_cell[n][c] = d_new_cell;
const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
const auto d_input_gate = candidate_cell[n][c] * d_new_cell;
d_gates[n][0][c] =
d_input_gate * d_sigmoid(gate_weights[n][0][c]);
d_gates[n][1][c] =
d_output_gate * d_sigmoid(gate_weights[n][1][c]);
d_gates[n][2][c] =
d_candidate_cell * d_elu(gate_weights[n][2][c]);
}
}
std::vector<torch::Tensor> lltm_cuda_backward(
torch::Tensor grad_h,
torch::Tensor grad_cell,
torch::Tensor new_cell,
torch::Tensor input_gate,
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gates,
torch::Tensor weights) {
auto d_old_cell = torch::zeros_like(new_cell);
auto d_gates = torch::zeros_like(gates);
const auto batch_size = new_cell.size(0);
const auto state_size = new_cell.size(1);
const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);
AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_backward_cuda", ([&] {
lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
d_old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
d_gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
grad_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
grad_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
}));
auto d_gate_weights = d_gates.reshape({batch_size, 3*state_size});
auto d_weights = d_gate_weights.t().mm(X);
auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);
auto d_X = d_gate_weights.mm(weights);
auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
auto d_input = d_X.slice(/*dim=*/1, state_size);
return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
}
將 C++/CUDA 運算與 PyTorch 整合¶
將啟用 CUDA 的運算與 PyTorch 整合也非常簡單。 如果您想編寫一個 setup.py
腳本,它可以如下所示
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='lltm',
ext_modules=[
CUDAExtension('lltm_cuda', [
'lltm_cuda.cpp',
'lltm_cuda_kernel.cu',
])
],
cmdclass={
'build_ext': BuildExtension
})
我們現在使用 CUDAExtension()
,而不是 CppExtension()
。 我們可以只指定 .cu
檔案以及 .cpp
檔案 – 函式庫會為您處理所有相關的麻煩。 JIT 機制甚至更簡單
from torch.utils.cpp_extension import load
lltm = load(name='lltm', sources=['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'])
效能比較¶
我們希望透過使用 CUDA 並行化和融合程式碼的逐點運算來提高 LLTM 的效能。 讓我們看看這是否成立。 我們可以執行我先前列出的程式碼來執行基準測試。 我們之前最快的版本是基於 CUDA 的 C++ 程式碼
Forward: 149.802 us | Backward 393.458 us
現在使用我們的自定義 CUDA 核心
Forward: 129.431 us | Backward 304.641 us
效能進一步提升!