
(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)

Created On: Mar 15, 2023 | Last Updated: Oct 09, 2024 | Last Verified: Nov 05, 2024

Author: Driss Guessous

Summary

In this tutorial, we want to highlight a new torch.nn.functional function that can be helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.
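As a quick, hedged illustration of that integration, the short sketch below simply runs a stock torch.nn.TransformerEncoderLayer; since these modules call scaled_dot_product_attention internally, they can pick up the fused kernels without any code changes. The layer sizes here are arbitrary.

import torch
import torch.nn as nn

# Hedged sketch: a standard encoder layer whose attention is backed by SDPA internally
encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
src = torch.randn(4, 10, 256)  # (batch, sequence, embedding)
out = encoder_layer(src)
print(out.shape)  # torch.Size([4, 10, 256])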

Overview

At a high level, this PyTorch function computes the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention is all you need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.
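To make the "can be written using existing functions" point concrete, here is a minimal sketch of the naive composition (matmul, scale, softmax, matmul); the fused F.scaled_dot_product_attention should produce a numerically close result. The tensor shapes are arbitrary.

import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 3, 8) for _ in range(3))
scale = 1.0 / math.sqrt(q.size(-1))
# Naive composition: softmax(q @ k^T / sqrt(d)) @ v
naive = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(naive, fused, atol=1e-6))  # should print True, up to floating point tolerance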

Fused implementations

For CUDA tensor inputs, the function will dispatch into one of the following implementations:

- FlashAttention
- Memory-Efficient Attention
- A PyTorch implementation defined in C++ (the math backend)

Note

This tutorial requires PyTorch 2.0.0 or later.

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[-1.3321, -0.3489,  0.3015, -0.3912,  0.9867,  0.3137, -0.0691,
          -1.2593],
         [-1.0882,  0.2506,  0.6491,  0.1360,  0.5238, -0.2448, -0.0820,
          -0.6171],
         [-1.0012,  0.3990,  0.6441, -0.0277,  0.5325, -0.2564, -0.0607,
          -0.6404]],

        [[ 0.6091,  0.0708,  0.6188,  0.3252, -0.1598,  0.4197, -0.2335,
           0.0630],
         [ 0.5285,  0.3890, -0.2649,  0.3706, -0.3839,  0.1963, -0.6242,
           0.2312],
         [ 0.4048,  0.0762,  0.3777,  0.4689, -0.2978,  0.2754, -0.6429,
           0.1037]]], device='cuda:0')

Explicit Dispatcher Control

While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch via the use of a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through and measure performance.

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel


with sdpa_kernel(SDPBackend.MATH):
    math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2326.842 microseconds
The math implementation runs in 87127.629 microseconds
The flash attention implementation runs in 2332.811 microseconds
The memory efficient implementation runs in 4343.845 microseconds
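As a hedged follow-up to the sweep idea above, the loop below (reusing the imports, tensors, and benchmark helper defined in this section) iterates over the three backends with sdpa_kernel and simply reports the fastest one for these particular inputs; unsupported backends raise RuntimeError and are skipped.

timings = {}
for backend in (SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION):
    with sdpa_kernel(backend):
        try:
            timings[backend] = benchmark_torch_function_in_microseconds(
                F.scaled_dot_product_attention, query, key, value
            )
        except RuntimeError:
            # Backend not supported for these inputs / this hardware
            pass

fastest = min(timings, key=timings.get)
print(f"Fastest backend for these inputs: {fastest} ({timings[fastest]:.3f} microseconds)")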

Hardware dependence

Depending on what machine you ran the above cell on and what hardware is available, your results might be different.

- If you don't have a GPU and are running on CPU, then with FP32 the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, flash attention or memory efficient attention might have failed.
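If flash attention or memory efficient attention failed for you, a quick way to see what hardware you are on is the small, hedged check below; it only uses torch.cuda introspection and does not change any behavior.

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Running on {torch.cuda.get_device_name()} (compute capability {major}.{minor})")
else:
    print("No GPU detected; running on CPU.")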

Causal Self Attention

Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy's NanoGPT repository.

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)
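As a small sanity check (not part of the original example), the snippet below runs a forward pass through the module defined above and confirms that the output shape matches the input shape.

# Hedged sanity check, assuming the ``model``, ``embed_dimension``, and ``dtype`` defined above
x_check = torch.rand(4, 128, embed_dimension, device="cuda", dtype=dtype)
with torch.no_grad():
    y_check = model(x_check)
print(y_check.shape)  # expected: torch.Size([4, 128, 512])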

NestedTensor and Dense tensor support

SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences, without needing to pad every sequence in the batch to the maximum length. For more information about NestedTensors see torch.nested and the NestedTensors tutorial.
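For readers unfamiliar with NestedTensors, here is a short, hedged illustration of the concept: a few sequences of different lengths are stored without padding, and can be converted to a padded dense tensor when a rectangular layout is needed.

# Three sequences of different lengths, stored without padding
nt = torch.nested.nested_tensor(
    [torch.randn(5, 8), torch.randn(3, 8), torch.randn(7, 8)]
)
print(nt.is_nested)  # True
padded = torch.nested.to_padded_tensor(nt, 0.0)  # pad with zeros to the longest sequence
print(padded.shape)  # torch.Size([3, 7, 8])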

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
/usr/local/lib/python3.10/dist-packages/torch/nested/__init__.py:228: UserWarning:

The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.)

Random NT runs in 563.476 microseconds
Random Dense runs in 947.731 microseconds

Using SDPA with torch.compile

With the release of PyTorch 2.0, a new feature called torch.compile() was introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let's compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in  {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# The first call to the compiled model triggers compilation (warm-up)
compiled_model(x)
print(
    f"The compiled module runs in  {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in  415.033 microseconds
The compiled module runs in  515.460 microseconds

The exact execution time is dependent on the machine, however the results on mine were: the non-compiled module ran in 166.616 microseconds and the compiled module ran in 166.726 microseconds. That is not what we were expecting. Let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
#
# .. code-block:: python
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Non-Compilied Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.559ms       101.70%      10.559ms      10.559ms             1
                         Non-Compilied Causal Attention        20.46%       2.262ms        75.78%       8.377ms       8.377ms       0.000us         0.00%      10.383ms      10.383ms             1
                                           aten::linear         1.14%     126.082us        28.04%       3.100ms      61.993us       0.000us         0.00%       7.781ms     155.623us            50
                                           aten::matmul         2.25%     248.826us        24.10%       2.664ms      53.285us       0.000us         0.00%       7.781ms     155.623us            50
                                               aten::mm        15.06%       1.665ms        19.47%       2.152ms      43.047us       7.781ms        74.94%       7.781ms     155.623us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.572ms        53.67%       5.572ms     222.899us            25
                     aten::scaled_dot_product_attention         1.96%     216.958us        18.01%       1.990ms      79.616us       0.000us         0.00%       2.601ms     104.057us            25
              aten::_scaled_dot_product_flash_attention         3.02%     334.253us        16.04%       1.773ms      70.938us       0.000us         0.00%       2.601ms     104.057us            25
                         aten::_flash_attention_forward         3.74%     413.039us        11.22%       1.241ms      49.633us       2.601ms        25.06%       2.601ms     104.057us            25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.601ms        25.06%       2.601ms     104.057us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.055ms
Self CUDA time total: 10.383ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                              Compiled Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.463ms       100.71%      10.463ms      10.463ms             1
                              Compiled Causal Attention         9.15%       1.015ms        74.41%       8.258ms       8.258ms       0.000us         0.00%      10.389ms      10.389ms             1
                             Torch-Compiled Region: 2/0         8.42%     934.690us        63.27%       7.022ms     280.881us       0.000us         0.00%      10.389ms     415.559us            25
                                       CompiledFunction        25.95%       2.880ms        54.85%       6.087ms     243.494us       0.000us         0.00%      10.389ms     415.559us            25
                                               aten::mm         9.55%       1.059ms        14.12%       1.568ms      31.353us       7.784ms        74.92%       7.784ms     155.676us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.576ms        53.67%       5.576ms     223.034us            25
              aten::_scaled_dot_product_flash_attention         2.17%     240.384us        14.77%       1.640ms      65.593us       0.000us         0.00%       2.605ms     104.208us            25
                         aten::_flash_attention_forward         3.87%     429.608us        10.87%       1.206ms      48.248us       2.605ms        25.08%       2.605ms     104.208us            25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.605ms        25.08%       2.605ms     104.208us            25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       2.208ms        21.25%       2.208ms      88.318us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.099ms
Self CUDA time total: 10.389ms

The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis shows that the majority of time spent on the GPU is concentrated on the same set of functions for both modules. The reason for this is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case CausalSelfAttention is, then the overhead of PyTorch can be hidden.
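To make the overhead argument a bit more tangible, the hedged sketch below (reusing model, compiled_model, and the benchmark helper from earlier) times both modules on a much smaller input, where the GPU kernels are cheap and the framework overhead that torch.compile removes accounts for a larger share of the total time.

# Hedged sketch: a tiny workload where per-call framework overhead dominates
small_x = torch.rand(2, 16, embed_dimension, device=device, dtype=dtype)
compiled_model(small_x)  # warm-up so (re)compilation for the new shape happens outside the timed region
eager_us = benchmark_torch_function_in_microseconds(model, small_x)
compiled_us = benchmark_torch_function_in_microseconds(compiled_model, small_x)
print(f"Eager (small input):    {eager_us:.3f} microseconds")
print(f"Compiled (small input): {compiled_us:.3f} microseconds")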

In reality, your module does not normally consist of a single CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module took the time per train step from 6090.49ms to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT, training on the Shakespeare dataset.

Using SDPA with attn_bias subclasses

# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses.
# Designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# .. note::
#    The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
#    is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#

from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)

upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)

print(type(upper_left_bias))
print(type(lower_right_bias))

assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)

# As you can see from the previous output, both biases are of the same type,
# ``torch.nn.attention.bias.CausalBias``, which is a subclass of ``torch.Tensor``

# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)

# Upper left bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# With upper left bias, the 0th token in the query is aligned to the 0th token in the key.
# With lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k
# (for example, ``attn_score[-1][-1]`` is always True, since the last token in q is at the same position as the last
# token in k, even if the sequence lengths of q and k are different).

# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)

# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])

Conclusion

In this tutorial, we have demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We have shown how the sdpa_kernel context manager can be used to assert that a certain implementation is used on GPU. As well, we built a simple CausalSelfAttention module that works with NestedTensor and is torch compilable. In the process we have shown how the profiling tools can be used to explore the performance characteristics of a user defined module.

Total running time of the script: (0 minutes 7.314 seconds)
