注意
點擊這裡以下載完整範例程式碼
將使用者定義的 Triton Kernels 與 torch.compile
搭配使用¶
建立於:2024 年 4 月 19 日 | 上次更新:2025 年 1 月 24 日 | 上次驗證:2024 年 11 月 05 日
作者: Oguz Ulgen
使用者定義的 Triton kernels 可以用於最佳化模型計算的特定部分。這些 kernels 是使用 Triton 語言編寫的,該語言旨在更容易實現最佳硬體效能。通過將使用者定義的 Triton kernels 與 torch.compile
搭配使用,您可以將這些最佳化的計算整合到您的 PyTorch 模型中,並可能實現顯著的效能改進。
此食譜示範了如何將使用者定義的 Triton kernels 與 torch.compile
搭配使用。
先決條件¶
在開始此食譜之前,請確保您具備以下條件
對
torch.compile
和 Triton 的基本了解。請參閱PyTorch 2.3 或更高版本
支援 Triton 的 GPU
import torch
from torch.utils._triton import has_triton
基本用法¶
在本範例中,我們將使用 Triton 文件中的簡單向量加法 kernel 以及 torch.compile
。如需參考,請參閱 Triton 文件。
if not has_triton():
print("Skipping because triton is not supported on this device.")
else:
import triton
from triton import language as tl
@triton.jit
def add_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@torch.compile(fullgraph=True)
def add_fn(x, y):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
return output
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X: tensor([ 0.1940, 2.1614, -0.1721, 0.8491], device='cuda:0')
Y: tensor([ 0.1391, -0.1082, -0.7174, 0.7566], device='cuda:0')
is equal to
tensor([ 0.3332, 2.0532, -0.8895, 1.6057], device='cuda:0')
進階用法¶
Triton 的自動調整功能是一個強大的工具,可以自動最佳化 Triton kernels 的配置參數。它會探索一系列可能的配置,並選擇最適合您特定使用案例的配置。
與 torch.compile
搭配使用時,triton.autotune
可以幫助確保您的 PyTorch 模型以盡可能高的效率運行。以下是使用 torch.compile
和 triton.autotune
的範例。
注意
torch.compile
僅支援 triton.autotune
的 configs 和 key 參數。
if not has_triton():
print("Skipping because triton is not supported on this device.")
else:
import triton
from triton import language as tl
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
],
key=[],
)
@triton.jit
def add_kernel_autotuned(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@torch.compile(fullgraph=True)
def add_fn(x, y):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_autotuned[grid](x, y, output, n_elements)
return output
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X: tensor([-0.5187, 1.2268, 0.6255, -0.9117], device='cuda:0')
Y: tensor([-0.6974, -1.8688, -0.8832, -1.6627], device='cuda:0')
is equal to
tensor([-1.2161, -0.6421, -0.2577, -2.5744], device='cuda:0')
可組合性¶
使用者定義的 Triton kernels 不會自動支援所有 PyTorch 子系統。這可以在以下使用案例中看到
新增 CPU 回退
新增
FlopCounter
公式與張量子類別組合
若要與其他 PyTorch 子系統組合,請使用 torch.library.triton_op
。
triton_op is
是一種結構化的方式,用於定義由一個或多個 Triton kernels 支援的自定義運算符:與常規自定義運算符 (torch.library.custom_op
) 類似,您可以通過 torch.library
指定與 PyTorch 子系統的互動。但是,與 torch.library.custom_op
不同,後者創建了關於 torch.compile
的不透明可調用對象,torch.compile
會追蹤到 triton_op
以應用最佳化。
以下是用於將 Triton kernels 與 PyTorch 整合時要使用的 API 圖表。
Triton kernel(沒有明確的 |
|
|
|
---|---|---|---|
支援推論 |
是 |
是 |
是 |
支援訓練 |
在大多數情況下 |
是 |
是 |
支援 |
是 |
是 |
是 |
支援 |
在大多數情況下 |
在大多數情況下 |
在所有情況下 |
torch.compile 是否追蹤到實作? |
是 |
是 |
否 |
支援 AOTInductor |
是 |
是 |
否 |
支援 PyTorch 子系統,例如 FlopCounterMode、CPU 回退、張量子類別 |
否 |
是 |
是 |
使用 triton_op
包裝 Triton 核心函式¶
使用 torch.library.triton_op
來包裝可能會調用一個或多個 Triton 核心函式的函式。 使用 torch.library.wrap_triton
來包裝對 Triton 核心函式的調用。
from torch.library import triton_op, wrap_triton
@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n_elements = x.numel()
wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
@triton.jit
def sin_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = tl.sin(x)
tl.store(out_ptr + offsets, output, mask=mask)
def sin_triton(x):
out = torch.empty_like(x)
n_elements = x.numel()
sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
您可以透過以下兩種方式之一調用 triton_op
。
x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)
assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())
產生的 triton_op
可與 torch.compile
和 AOTInductor
一起使用。
y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
新增訓練支援¶
使用 register_autograd
為 triton_op
新增 autograd 公式。 優先使用此方法,而不是使用 torch.autograd.Function
(它與 torch.compile
有各種組合性的潛在問題)。
請注意,backward 必須是 PyTorch 可理解的運算子的組合。 如果您希望 backward 調用 Triton 核心函式,則這些函式也必須包裝在 triton_op
中
@triton.jit
def cos_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = tl.cos(x)
tl.store(out_ptr + offsets, output, mask=mask)
@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n_elements = x.numel()
wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
def backward(ctx, grad_output):
x, = ctx.saved_tensors
return grad_input * mycos(x)
def setup_context(ctx, inputs, output):
x, = inputs
ctx.save_for_backward(x)
mysin.register_autograd(backward, setup_context=setup_context)
新增 CPU 回退¶
Triton 核心函式無法在 CPU 上執行。 使用 register_kernel
為 triton_op
新增 CPU(或任何其他裝置)回退
回退必須由 PyTorch 運算子組成。
新增 FlopCounter 公式¶
若要指定 triton 核心函式在 PyTorch 的 flop 計數器下報告的浮點運算次數,請使用 register_flop_formula
。
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
numel = 1
for s in x_shape:
numel *= s
return numel
x = torch.randn(3, device="cuda")
FlopCounterMode
需要 tabulate。 在執行以下程式碼之前,請確保您已安裝 tabulate
,或透過執行 pip install tabulate
進行安裝。
限制¶
截至 PyTorch 2.3,torch.compile
中對使用者定義的 Triton 核心函式的支援包括動態形狀、torch.autograd.Function
、JIT inductor 和 AOT inductor。 您可以將這些功能結合在一起,以建構複雜、高效能的模型。
PyTorch 2.6 新增了 torch.library.triton_op
,它在張量子類別和其他進階功能中增加了對使用者定義的 Triton 核心函式的支援。
但是,需要注意一些限制
Triton 功能:雖然
triton.heuristics
可以單獨使用,也可以在triton.autotune
之前使用,但不能在triton.autotune
之後使用。 這表示如果triton.heuristics
和triton.autotune
要一起使用,則必須先使用triton.heuristics
。
結論¶
在本教學中,我們探索了如何將使用者定義的 Triton 核心函式與 torch.compile
結合使用。 我們深入研究了簡單向量加法核心函式的基本用法,以及涉及 Triton 自動調整功能的進階用法。 我們還討論了使用者定義的 Triton 核心函式與其他 PyTorch 功能的組合性,並強調了一些目前的限制。