• 教學 >
  • 透過用巢狀張量和 torch.compile() 取代 nn.Transformer 來加速 PyTorch Transformer
捷徑

透過用巢狀張量和 torch.compile() 取代 nn.Transformer 來加速 PyTorch Transformer

作者: Mikayla Gawarecki

您將學到什麼
  • 瞭解 PyTorch 提供的低階建構區塊,用於建置自訂 Transformer 層(巢狀張量、scaled_dot_product_attentiontorch.compile()FlexAttention

  • 探索如何以上述元件透過 MultiHeadAttention 為例,改善記憶體使用率和效能

  • 使用上述建構區塊探索進階自訂

先決條件
  • PyTorch v.2.6.0 或更高版本

過去幾年,PyTorch 團隊開發了各種低階功能,這些功能組合起來可以建立各種 Transformer 變體。這些包括

  • 具有 torch.jagged 佈局(又名 NJT)的巢狀張量

  • scaled_dot_product_attention

  • torch.compile()

  • FlexAttention

本教學將簡要概述上述技術,並示範如何將它們組合起來,以產生具有改進使用者體驗、彈性和效能的 Transformer 層。

可以觀察到 torch.nn 模組目前提供了各種與 Transformer 相關的層。 特別是,它包括 TransformerEncoderLayerTransformerEncoderTransformerDecoderLayerTransformerDecoderTransformerMultiheadAttention。 這個層系列最初是根據 Attention is All You Need 論文實作的。 與現有的 nn 層相比,本教學中討論的元件提供了改進的使用者體驗、彈性和效能。

本教學適合我嗎?

如果您想知道 torch 函式庫提供了哪些建構區塊來編寫您自己的 Transformer 層和最佳實務,那麼您來對地方了。 請繼續閱讀!

如果您正在尋找熱門 Transformer 架構的開箱即用實作,請注意有許多開放原始碼函式庫提供這些實作,包括

如果您只對高效能的注意力分數修改感興趣,請查看 FlexAttention 部落格,其中包含一個 mask gym

建構區塊介紹

首先,我們將簡要介紹引言中提到的四種技術

巢狀張量推廣了規則稠密張量的形狀,允許使用相同的張量 UX 表示參差不齊大小的資料。 在 Transformer 的上下文中,我們可以將巢狀張量視為表示可變序列長度的工具。 它們消除了對容易出錯的顯式填充和遮罩的需求(想想 nn.MultiHeadAttention 中的 key_padding_mask)。

scaled_dot_product_attention\(\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V\) 的基本元件,它會分派到運算子的融合實作或後備實作。 它在 eager 模式(即使用 PyTorch 的預設模式,其中操作會在遇到時立即執行)中開箱即用,並且還可以與 torch.compile() 無縫整合。 從 2.6 開始,它還將原生提供分組查詢注意力。

torch.compile() 是 2.0 版中引入的編譯器,能夠擷取 PyTorch 程式碼的圖形並對其執行各種最佳化,例如融合操作序列。 具有 torch.jagged 佈局的巢狀張量和 scaled_dot_product_attention 與編譯無縫協同工作。 在 Transformer 的上下文中,使用編譯與巢狀張量和 SDPA 的價值在於編譯可以消除在 eager 模式下看到的框架開銷,並將 Transformer 中的操作序列融合在一起,例如投影和啟用。

FlexAttention 是一種基本元件,允許使用者在 softmax 運算之前修改注意力分數。 它推廣了上述 scaled_dot_product_attention 的加法 B 術語,允許任意計算。 它需要編譯才能獲得良好的效能。

以上建構區塊是「您所需要的一切」(截至 2024 年 10 月)

本節的主要前提是,大多數 Transformer 變體都是 GPT 樣式,由 Embedding、位置編碼、注意力區塊和前饋網路等層組成。 如果我們要嘗試對此空間中的差異進行分類,我們可能會得到類似的結果

  1. 層類型 (激活函數,例如 SwiGLU 等,標準化函數,例如 RMSNorm 等,位置編碼,例如正弦波、旋轉位置編碼等)。

  2. 層的順序,例如在哪裡應用標準化和位置編碼。

  3. 對注意力分數的修改,例如 ALiBi、相對位置偏差等等。

在預編譯器環境中,您可能會編寫一個自定義轉換器,並注意到它可以正常工作,但速度很慢。為了解決這個問題,您可以為特定的操作序列開發一個自定義融合核心。在編譯器環境中,您可以簡單地執行初始步驟,然後編譯並從提高的性能中受益。

MultiheadAttention

請記住,MultiheadAttention 接收 query、key 和 value,並包含一個輸入投影、一個 scaled_dot_product_attention 運算符和一個輸出投影。我們想要在這裡演示的主要重點是用嵌套張量替換填充/遮罩輸入後所帶來的改進。這些改進有三方面:

  • 使用者體驗 請記住,nn.MultiheadAttention 需要 querykeyvalue 為密集的 torch.Tensors。它還提供一個 key_padding_mask,用於遮罩 key 中因批次中不同序列長度而產生的填充 token。由於 nn.MHA 中沒有 query_padding_mask,因此使用者必須注意遮罩/切片輸出,以適當考慮 query 序列長度。NestedTensor 徹底消除了這種容易出錯的填充遮罩的需求。

  • 記憶體 嵌套張量允許您乾淨地表示不同序列長度的批次,而不是實體化具有 [B, S] 填充遮罩的密集 [B, S, D] 張量(其中 B 是批次大小,S 是批次中的最大序列長度,D 是嵌入大小)。因此,輸入和中間激活將使用更少的記憶體。

  • 效能 由於填充未被實體化,並且跳過了對填充的不必要計算,因此效能和記憶體使用率得到了提高。

我們將通過建立在 嵌套張量教學 中的 MultiheadAttention 層上,並將其與 nn.MultiheadAttention 層進行比較來展示以上内容。

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
        bias (bool, optional): Whether to add bias to input projection. Default: True
    """

    def __init__(
        self,
        E_q: int,
        E_k: int,
        E_v: int,
        E_total: int,
        nheads: int,
        dropout: float = 0.0,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.nheads = nheads
        self.dropout = dropout
        self._qkv_same_embed_dim = E_q == E_k and E_q == E_v
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
        else:
            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
        self.bias = bias

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask=None,
        is_causal=False,
    ) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
            key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
            value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
            attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
            is_causal (bool, optional): Whether to apply causal mask. Default: False

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        if self._qkv_same_embed_dim:
            if query is key and key is value:
                result = self.packed_proj(query)
                query, key, value = torch.chunk(result, 3, dim=-1)
            else:
                q_weight, k_weight, v_weight = torch.chunk(
                    self.packed_proj.weight, 3, dim=0
                )
                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(
                        self.packed_proj.bias, 3, dim=0
                    )
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = (
                    F.linear(query, q_weight, q_bias),
                    F.linear(key, k_weight, k_bias),
                    F.linear(value, v_weight, v_bias),
                )

        else:
            query = self.q_proj(query)
            key = self.k_proj(key)
            value = self.v_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=self.dropout, is_causal=is_causal
        )
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

實用工具

在本節中,我們包含一個實用工具,用於使用 Zipf 分布生成半真實資料,以表示句子長度。這用於生成嵌套的 query、key 和 value 張量。我們還包括一個基準測試實用工具。

import numpy as np


def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)


# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
# in the form of nested tensors with the jagged layout.
def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    if query_seq_len_1:
        query = torch.nested.nested_tensor(
            [torch.randn(1, E_q, dtype=dtype, device=device) for l in sentence_lengths],
            layout=torch.jagged,
        )
    else:
        query = torch.nested.nested_tensor(
            [
                torch.randn(l.item(), E_q, dtype=dtype, device=device)
                for l in sentence_lengths
            ],
            layout=torch.jagged,
        )

    key = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_k, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    value = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_v, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    return query, key, value, sentence_lengths


import math
import timeit


def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin), torch.cuda.max_memory_allocated()

現在,我們將演示在 MultiheadAttention 層中使用嵌套張量 + 編譯進行自注意力運算時的效能提升。我們將其與使用填充和遮罩的傳統 nn.MultiheadAttention + 編譯進行比較。

N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512
E_out = E_q
d_model = E_q
nheads = 8
dropout = 0.0
bias = True
device = "cuda"
torch.manual_seed(6)
query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)
S = sentence_lengths.max().item()
print(
    f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}"
)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

torch.manual_seed(6)
mha_layer = MultiHeadAttention(
    E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device="cuda"
)
torch.manual_seed(6)
vanilla_mha_layer = nn.MultiheadAttention(
    E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device="cuda"
)

# ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(
    vanilla_mha_layer.out_proj.weight.clone().detach()
)
mha_layer.packed_proj.weight = nn.Parameter(
    vanilla_mha_layer.in_proj_weight.clone().detach()
)
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
mha_layer.packed_proj.bias = nn.Parameter(
    vanilla_mha_layer.in_proj_bias.clone().detach()
)

new_mha_layer = torch.compile(mha_layer)
# warmup compile
nested_result_warmup = new_mha_layer(query, query, query, is_causal=True)

# benchmark
nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, query, query, is_causal=True
)
padded_nested_result = nested_result.to_padded_tensor(0.0)

# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
attn_mask = torch.empty((N, S, S), device=device).fill_(float("-inf"))
for i, s in enumerate(sentence_lengths):
    attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s)
attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N * nheads, S, S)

vanilla_mha_layer = torch.compile(vanilla_mha_layer)
# warmup compile
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_query,
    padded_query,
    attn_mask=attn_mask,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    is_causal=True,
)

# benchmark
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_query,
    padded_query,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    attn_mask=attn_mask,
    is_causal=True,
)

print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Total sequence length in nested query 10436, max sequence length 128
padded_time=0.01611, padded_peak_memory=3.88 GB
nested_time=0.00234, nested_peak_memory=0.93 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 6.87
Nested peak memory reduction 2.96 GB

為了參考,以下是在 A100 上的一些範例輸出:

padded_time=0.03454, padded_peak_memory=4.14 GB
nested_time=0.00612, nested_peak_memory=0.76 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 5.65
Nested peak memory reduction 3.39 GB

我們也可以看到反向傳遞的相同結果

for i, entry_length in enumerate(sentence_lengths):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

_, padded_bw_time, padded_bw_peak_mem = benchmark(
    lambda: padded_result.sum().backward()
)
_, nested_bw_time, nested_bw_peak_mem = benchmark(
    lambda: padded_nested_result.sum().backward()
)

print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB")
print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB")
print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}")
print(
    f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB"
)

print(
    "Difference in out_proj.weight.grad",
    (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in packed_proj.weight.grad",
    (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in out_proj.bias.grad",
    (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in packed_proj.bias.grad",
    (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad)
    .abs()
    .max()
    .item(),
)
padded_bw_time=1.73546, padded_bw_peak_mem=4.69 GB
nested_bw_time=0.11116, nested_bw_peak_mem=3.14 GB
Nested backward speedup: 15.61
Nested backward peak memory reduction 1.55 GB
Difference in out_proj.weight.grad 0.000396728515625
Difference in packed_proj.weight.grad 0.00146484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.0029296875

A100 上的範例輸出

padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
Nested backward speedup: 144.13
Nested backward peak memory reduction 1.86 GB
Difference in out_proj.weight.grad 0.000244140625
Difference in packed_proj.weight.grad 0.001556396484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.001953125

GPT 樣式層

一個基本的 GPT 樣式轉換器層由一個因果自注意力層和一個帶有跳過連接的前饋網路 (FFN) 組成。使用上面的 MultiheadAttention 層實現這一點非常簡單,並且會產生與 nn.TransformerEncoderLayer (帶有 is_causal=True) 等效的結果。

我們在這裡展示了實現其餘 nn 層的範例,但為了簡潔起見,我們將其從本教學中省略。

更進一步

到目前為止,我們已經演示了如何實現一個遵循傳統 nn.MultiheadAttention 的高效能 MultiheadAttention 層。回到我們對轉換器架構修改的分類,請記住我們將修改分類為層類型、層順序和對注意力分數的修改。我們相信更改層類型和層順序(例如將 LayerNorm 替換為 RMSNorm)非常簡單。

在本節中,我們將討論使用上述構建模組的各種功能,包括以下內容:

  • 交叉注意力

  • 完全遮罩的行不再導致 NaNs

  • 修改注意力分數:使用 FlexAttention 和 NJT 的 ALiBi

  • 壓縮投影

交叉注意力

交叉注意力是一種注意力形式,其中 query 和 key/value 張量來自不同的序列。

這方面的一個範例是在 nn.TransformerDecoderLayer 中,其中 query 來自解碼器,而 key/value 來自編碼器。

上面的 MultiheadAttention 層很好地概括了這種情況,query 和 key/value 都使用嵌套張量。

query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)

print(
    f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}"
)
print(
    f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}"
)
out = new_mha_layer(query, key, value, is_causal=False)
Total sequence length in nested query 10617, max sequence length 165
Total sequence length in nested key/value 10176, max sequence length 137

與上述一樣,我們可以將其與原始編譯的 nn.MultiheadAttention 進行比較。

torch.manual_seed(6)
query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]

# warmup compile
warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)

nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, key, value, is_causal=False
)
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)
padded_nested_result = nested_result.to_padded_tensor(0.0)
for i, entry_length in enumerate(q_len):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Max difference between vanilla and nested result 0.0
Nested speedup: 4.95
Nested peak memory reduction 1.20 GB

A100 上的範例輸出

Max difference between vanilla and nested result 0.0
Nested speedup: 4.01
Nested peak memory reduction 1.40 GB

完全遮罩的行不再導致 NaNs

長久以來,nn.MultiheadAttentionscaled_dot_product_attention 存在一個問題:如果某一行完全被遮罩 (masked out),注意力層的輸出會變成 NaN。請參閱 issue。這是因為對空集合進行 softmax 運算未定義。

感謝 這個 PR,這個問題已獲得解決。現在,scaled_dot_product_attention 中與完全遮罩的行對應的輸出將會是 0。對於 nn.MHA 未使用 “fast-path” 的情況,這也適用。

強烈建議使用帶有 NJT 的自定義 MHA 層,而不是 nn.MultiheadAttention 中的現有 “fast-path”,因為 NJT 適當建模不規則性的能力,使其可以正確表達空序列。

FlexAttention + NJT

NJT 也與 FlexAttention 模組組合使用。這是 MultiheadAttention 層的概括,允許對注意力分數進行任意修改。下面的示例採用了 alibi_mod,它實現了來自 attention gymALiBi,並將其與巢狀輸入張量一起使用。

from torch.nn.attention.flex_attention import flex_attention


def generate_alibi_bias(H: int):
    """Returns an alibi bias score_mod given the number of heads H
    Args:
        H: number of heads
    Returns:
        alibi_bias: alibi bias score_mod
    """

    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (q_idx - kv_idx) * scale
        return score + bias

    return alibi_mod


query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)

此外,還可以透過 create_nested_block_mask 函式,將 FlexAttentionblock_mask 工具與 NJT 一起使用。 這對於利用遮罩的稀疏性來加速注意力計算非常有用。 特別是,該函數為所有組合到 NJT 中的可變長度序列的 “堆疊序列” 建立一個稀疏區塊遮罩,同時正確地遮罩掉序列間的注意力。 在下面的例子中,我們展示如何使用此工具建立因果區塊遮罩。

from torch.nn.attention.flex_attention import create_nested_block_mask


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex = flex_attention(query, key, value, block_mask=block_mask)

打包投影 (Packed Projection)

打包投影是一種技術,它利用了這樣一個事實:當投影的輸入(矩陣乘法)相同(自注意力)時,我們可以將投影權重和偏差打包成單個張量。當個別投影受記憶體限制而不是受計算限制時,它特別有用。我們將在此處展示兩個範例

  • MultiheadAttention 的輸入投影

  • Transformer 層的前饋網路中的 SwiGLU 激活

MultiheadAttention 的輸入投影

在執行自注意力時,querykeyvalue 是相同的張量。這些張量中的每一個都使用 Linear(E_q, E_total) 層進行投影。相反,我們可以將其打包到一個層中,這就是我們在上面的 MultiheadAttention 層中所做的。

讓我們比較一下打包投影與常用方法的效能

class InputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

    def forward(self, x):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)


class PackedInputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)

    def forward(self, query):
        return torch.chunk(self.packed_proj(query), 3, dim=-1)


B, D, dtype = 256, 8192, torch.bfloat16

torch.set_float32_matmul_precision("high")
in_proj = torch.compile(InputProjection(D, D, device="cuda", dtype=torch.bfloat16))
packed_in_proj = torch.compile(
    PackedInputProjection(D, D, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sequence_lengths = gen_batch(B, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
in_proj(q)
packed_in_proj(q)

# benchmark
(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
# On my A100 prints 1.05x speedup
print(
    f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x"
)
InputProjection: 0.034049 s, PackedInputProjection: 0.032743 s, speedup: 1.04x

Transformer 層的 SwiGLU 前饋網路

Swish-Gated Linear Unit (SwiGLU) 是一種非線性激活函數,在 transformer 層的前饋網路中越來越受歡迎(例如 Llama)。 具有 SwiGLU 激活的前饋網路定義為

class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

使用打包投影實現此目的的另一種方法是

class PackedSwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)

    def forward(self, x):
        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
        return self.w2(F.silu(x1) * x3)

我們可以按如下方式比較兩個實現的性能。根據您的硬體,您可能會看到不同的結果。在 A100 上,我看到 D=128 時加速了 1.12 倍。

D = 128

swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16))
packed_swigluffn = torch.compile(
    PackedSwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
swigluffn(q)
packed_swigluffn(q)

# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
# On my A100 prints 1.08x speedup
print(
    f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x"
)
SwiGLUFFN: 0.0010555100000146922 s, PackedSwiGLUFFN: 0.0009598690000984789 s, speedup: 1.10x

擴展範例

我們打算更新本教程,以展示更多如何使用各種高效能構建模組(例如 KV-Caching、Grouped Query Attention 等)的範例。 此外,還有一些使用各種高效能構建模組來實現各種 transformer 架構的優秀範例。 一些例子包括

結論

在本教程中,我們介紹了 PyTorch 為編寫 transformer 層提供的底層構建模組,並示範了如何組合它們的範例。 我們希望本教程能夠教育讀者,PyTorch 用戶可以輕鬆地實現靈活且高效能的 transformer 層。

腳本的總執行時間: ( 1 分鐘 21.016 秒)

由 Sphinx-Gallery 產生

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教程

檢視教程

資源

尋找開發資源並解答您的問題

檢視資源