透過 Pallas 的自訂核心¶
隨著 OpenAI Triton 的興起,自訂核心在 GPU 社群中變得越來越普及,例如 FlashAttention 和 PagedAttention 的引入。 為了在 TPU 領域提供功能對等性,Google 推出了 Pallas。 為了讓 PyTorch/XLA 能夠持續提升 TPU 的效能,我們必須支援自訂核心,而最佳途徑便是透過 Pallas。
假設您已定義一個 Pallas 核心如下
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
請注意,在匯入任何 jax 模組之前,務必先執行 jax_import_guard()
。 否則,程式將在 TPU 上停滯,因為 jax 會鎖定 TPU,導致 torch-xla 無法存取。
採用上述核心以相容於 PyTorch/XLA¶
使用範例
q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")
# Adopts any Pallas kernel
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
對於簡單的核心,採用方式就像單行程式碼一樣簡單。 若是更複雜的核心,您可以參考我們的 Flash Attention 實作以取得詳細資訊。
使用內建核心¶
除了手動封裝外部 Pallas 核心,還有一些內建核心,其採用作業已由 PyTorch/XLA 完成。 這些內建核心可以像其他 torch.ops 一樣使用。 目前支援的內建核心包括:- FlashAttention -PagedAttention
FlashAttention¶
使用範例¶
# Use built-in kernels
import torch_xla.experimental.custom_kernel
output = flash_attention(q, k, v)
整合範例¶
我們在訓練測試腳本中提供了一個 FlashAttention 整合範例。
PagedAttention¶
使用範例¶
# Use built-in kernels
import torch_xla.experimental.custom_kernel
output = torch.ops.xla.paged_attention(
query.squeeze(dim=1),
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
megacore_mode=None,
)
整合範例¶
vLLM TPU 整合在此採用 PagedAttention,以便透過 KV 快取進行有效的記憶體管理。
相依性¶
Pallas 整合功能仰賴 JAX 才能運作。 然而,並非所有 JAX 版本都與您安裝的 PyTorch/XLA 相容。 若要安裝正確的 JAX
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html