捷徑

torch.nn.attention.flex_attention

torch.nn.attention.flex_attention.flex_attention(query, key, value, score_mod=None, block_mask=None, scale=None, enable_gqa=False, return_lse=False, kernel_options=None)[source][source]

此函數使用任意注意力分數修改函數來實現縮放點積注意力。

此函數計算 query、key 和 value 張量之間的縮放點積注意力,並使用使用者定義的注意力分數修改函數。 注意力分數修改函數將在 query 和 key 張量之間計算注意力分數後應用。 注意力分數計算方式如下

score_mod 函數應具有以下簽名

def score_mod(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    q_idx: Tensor,
    k_idx: Tensor
) -> Tensor:
其中
  • score:表示注意力分數的純量張量,與 query、key 和 value 張量具有相同的資料類型和裝置。

  • batchheadq_idxk_idx:純量張量,分別表示批次索引、查詢 head 索引、查詢索引和 key/value 索引。 這些張量應具有 torch.int 資料類型,並且位於與 score 張量相同的裝置上。

參數
  • query (Tensor) – Query 張量;形狀 (B,Hq,L,E)(B, Hq, L, E)

  • key (Tensor) – Key 張量;形狀 (B,Hkv,S,E)(B, Hkv, S, E)

  • value (Tensor) – Value 張量;形狀 (B,Hkv,S,Ev)(B, Hkv, S, Ev)

  • score_mod (Optional[Callable]) – 修改注意力分數的函數。 預設情況下,不套用 score_mod。

  • block_mask (Optional[BlockMask]) – BlockMask 物件,用於控制注意力的區塊稀疏模式。

  • scale (Optional[float]) – 在 softmax 之前套用的縮放因子。 如果為 none,則預設值設定為 1E\frac{1}{\sqrt{E}}

  • enable_gqa (bool) – 如果設定為 True,則啟用分組查詢注意力 (GQA) 並將 key/value head 廣播到 query head。

  • return_lse (bool) – 是否傳回注意力分數的 logsumexp。 預設值為 False。

  • kernel_options (選擇性[Dict[str, Any]]) – 傳遞到 Triton 核心的選項。

回傳值

注意力輸出;形狀 (B,Hq,L,Ev)(B, Hq, L, Ev)

回傳類型

output (Tensor)

形狀圖例
  • N:Batch size...:Any number of other batch dimensions (optional)N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}

  • S:Source sequence lengthS: \text{Source sequence length}

  • L:Target sequence lengthL: \text{Target sequence length}

  • E:Embedding dimension of the query and keyE: \text{Embedding dimension of the query and key}

  • Ev:Embedding dimension of the valueEv: \text{Embedding dimension of the value}

警告

torch.nn.attention.flex_attention 是 PyTorch 中的原型功能。 請期待未來 PyTorch 版本中更穩定的實現。 閱讀更多關於功能分類的資訊: https://pytorch.dev.org.tw/blog/pytorch-feature-classification-changes/#prototype

BlockMask 工具

torch.nn.attention.flex_attention.create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device='cuda', BLOCK_SIZE=128, _compile=False)[原始碼][原始碼]

此函數從 mask_mod 函數建立區塊遮罩元組。

參數
  • mask_mod (Callable) – mask_mod 函數。這是一個可呼叫物件,定義注意力機制的遮罩模式。 它接受四個參數:b(批次大小)、h(標頭數量)、q_idx(查詢索引)和 kv_idx(鍵/值索引)。 它應該傳回一個布林張量,指示允許哪些注意力連接 (True) 或遮罩掉哪些連接 (False)。

  • B (int) – 批次大小。

  • H (int) – 查詢標頭的數量。

  • Q_LEN (int) – 查詢的序列長度。

  • KV_LEN (int) – 鍵/值的序列長度。

  • device (str) – 執行遮罩建立的裝置。

  • BLOCK_SIZE ( int Tuple[int, int] ) – 用於區塊遮罩的區塊大小。 如果提供單一整數,則會同時用於查詢和鍵/值。

回傳值

一個包含區塊遮罩資訊的 BlockMask 物件。

回傳類型

BlockMask

使用範例
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda")
query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
output = flex_attention(query, key, value, block_mask=block_mask)
torch.nn.attention.flex_attention.create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device='cuda')[原始碼][原始碼]

此函數從 mod_fn 函數建立遮罩張量。

參數
  • mod_fn (Union[_score_mod_signature, _mask_mod_signature]) – 修改注意力分數的函數。

  • B (int) – 批次大小。

  • H (int) – 查詢標頭的數量。

  • Q_LEN (int) – 查詢的序列長度。

  • KV_LEN (int) – 鍵/值的序列長度。

  • device (str) – 執行遮罩建立的裝置。

回傳值

一個形狀為 (B, H, M, N) 的遮罩張量。

回傳類型

遮罩 (mask) (Tensor)

torch.nn.attention.flex_attention.create_nested_block_mask(mask_mod, B, H, q_nt, kv_nt=None, BLOCK_SIZE=128, _compile=False)[原始碼][原始碼]

此函數從 mask_mod 函數建立一個巢狀張量相容的區塊遮罩元組。 傳回的 BlockMask 將位於輸入巢狀張量指定的裝置上。

參數
  • mask_mod (Callable) – mask_mod 函數。這是一個可呼叫物件,定義注意力機制的遮罩模式。 它接受四個參數:b(批次大小)、h(標頭數量)、q_idx(查詢索引)和 kv_idx(鍵/值索引)。 它應該傳回一個布林張量,指示允許哪些注意力連接 (True) 或遮罩掉哪些連接 (False)。

  • B (int) – 批次大小。

  • H (int) – 查詢標頭的數量。

  • q_nt (torch.Tensor) – Jagged layout 巢狀張量 (NJT),用於定義查詢的序列長度結構。 將建構區塊遮罩以對序列長度為 sum(S) 的「堆疊序列」進行操作,其中 S 來自 NJT。

  • kv_nt (torch.Tensor) – Jagged layout 巢狀張量 (NJT),用於定義鍵/值的序列長度結構,允許交叉注意力。 將建構區塊遮罩以對序列長度為 sum(S) 的「堆疊序列」進行操作,其中 S 來自 NJT。 如果此值為 None,則 q_nt 也會用於定義鍵/值的結構。 預設值:None

  • BLOCK_SIZE ( int Tuple[int, int] ) – 用於區塊遮罩的區塊大小。 如果提供單一整數,則會同時用於查詢和鍵/值。

回傳值

一個包含區塊遮罩資訊的 BlockMask 物件。

回傳類型

BlockMask

使用範例
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
query = torch.nested.nested_tensor(..., layout=torch.jagged)
key = torch.nested.nested_tensor(..., layout=torch.jagged)
value = torch.nested.nested_tensor(..., layout=torch.jagged)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
output = flex_attention(query, key, value, block_mask=block_mask)
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
query = torch.nested.nested_tensor(..., layout=torch.jagged)
key = torch.nested.nested_tensor(..., layout=torch.jagged)
value = torch.nested.nested_tensor(..., layout=torch.jagged)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# cross attention case: pass both query and key/value NJTs
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, key, _compile=True)
output = flex_attention(query, key, value, block_mask=block_mask)
torch.nn.attention.flex_attention.and_masks(*mask_mods)[原始碼][原始碼]

傳回一個 mask_mod,它是提供的 mask_mods 的交集

回傳類型

Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]

torch.nn.attention.flex_attention.or_masks(*mask_mods)[原始碼][原始碼]

傳回一個 mask_mod,它是提供的 mask_mods 的聯集

回傳類型

Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]

torch.nn.attention.flex_attention.noop_mask(batch, head, token_q, token_kv)[原始碼][原始碼]

傳回一個 noop mask_mod

回傳類型

Tensor

BlockMask

class torch.nn.attention.flex_attention.BlockMask(seq_lengths, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices, full_q_num_blocks, full_q_indices, BLOCK_SIZE, mask_mod)[source][source]

BlockMask是我們用於表示塊稀疏(block-sparse)注意力遮罩的格式。它有點像是BCSR格式和非稀疏格式的混合。

基本概念

塊稀疏遮罩表示,不是表示遮罩中個別元素的稀疏性,而是只有當KV_BLOCK_SIZE x Q_BLOCK_SIZE塊中的每個元素都是稀疏的時,該塊才被認為是稀疏的。這與硬體非常吻合,因為硬體通常需要執行連續的加載和計算。

此格式主要針對 1. 簡單性和 2. 核心效率進行了優化。值得注意的是,它並未針對大小進行優化,因為此遮罩始終會縮小 KV_BLOCK_SIZE * Q_BLOCK_SIZE 倍。如果擔心大小,則可以通過增加塊大小來減小張量的大小。

我們格式的要點是

num_blocks_in_row: Tensor[ROWS]: 描述每行中存在的塊數。

col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]: col_indices[i]是第i行的塊位置序列。此行中在col_indices[i][num_blocks_in_row[i]]之後的值未定義。

例如,要從此格式重建原始張量

dense_mask = torch.zeros(ROWS, COLS)
for row in range(ROWS):
    for block_idx in range(num_blocks_in_row[row]):
        dense_mask[row, col_indices[row, block_idx]] = 1

值得注意的是,此格式使沿遮罩實現縮減變得更容易。

詳細資訊

我們格式的基本概念僅需要kv_num_blocks和kv_indices。但是,在此物件上最多有 8 個張量。這代表 4 對

1. (kv_num_blocks, kv_indices): 用於attention的前向傳遞,因為我們沿 KV 維度進行縮減。

2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): 這是可選的,純粹是一種優化。事實證明,將遮罩應用於每個塊的成本相當高!如果我們明確知道哪些塊是「完整的」並且根本不需要遮罩,那麼我們可以跳過將mask_mod應用於這些塊。這需要使用者從score_mod中分離出單獨的mask_mod。對於因果遮罩,這可以加速約 15%。

3. [GENERATED] (q_num_blocks, q_indices): 後向傳遞需要,因為計算 dKV 需要沿 Q 維度的遮罩進行迭代。這些是從 1 自動產生的。

4. [GENERATED] (full_q_num_blocks, full_q_indices): 與上述相同,但用於後向傳遞。這些是從 2 自動產生的。

BLOCK_SIZE: Tuple[int, int]
as_tuple(flatten=True)[source][source]

傳回 BlockMask 的屬性元組。

參數

flatten (bool) – 如果為 True,它將扁平化 (KV_BLOCK_SIZE, Q_BLOCK_SIZE) 的元組

classmethod from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks=None, full_kv_indices=None, BLOCK_SIZE=128, mask_mod=None, seq_lengths=None)[source][source]

從鍵值塊資訊建立 BlockMask 實例。

參數
  • kv_num_blocks (Tensor) – 每個 Q_BLOCK_SIZE 行平鋪(row tile)中的 kv_blocks 數。

  • kv_indices (Tensor) – 每個 Q_BLOCK_SIZE 行平鋪中的鍵值塊索引。

  • full_kv_num_blocks (Optional[Tensor]) – 每個 Q_BLOCK_SIZE 行平鋪中的完整 kv_blocks 數。

  • full_kv_indices (Optional[Tensor]) – 每個 Q_BLOCK_SIZE 行平鋪中的完整鍵值塊索引。

  • BLOCK_SIZE (Union[int, Tuple[int, int]]) – KV_BLOCK_SIZE x Q_BLOCK_SIZE 平鋪的大小。

  • mask_mod (Optional[Callable]) – 修改遮罩的函數 (選擇性[Callable])。

回傳值

透過 _transposed_ordered 產生的包含完整 Q 資訊的實例。

回傳類型

BlockMask

引發
full_kv_indices: Optional[Tensor]
full_kv_num_blocks: Optional[Tensor]
full_q_indices: Optional[Tensor]
full_q_num_blocks: Optional[Tensor]
kv_indices: Tensor
kv_num_blocks: Tensor
mask_mod: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
numel()[source][source]

傳回遮罩中元素的數量 (不考慮稀疏性)。

q_indices: Optional[Tensor]
q_num_blocks: Optional[Tensor]
seq_lengths: Tuple[int, int]
property shape
sparsity()[source][source]

計算稀疏區塊(即未計算的區塊)的百分比。

回傳類型

float

to(device)[source][source]

將 BlockMask 移動到指定的裝置。

參數

device (torch.devicestr) – 將 BlockMask 移動到的目標裝置。 可以是 torch.device 物件或字串 (例如,'cpu', 'cuda:0')。

回傳值

一個新的 BlockMask 實例,其中所有 tensor 組件都已移動到指定的裝置。

回傳類型

BlockMask

注意

此方法不會就地修改原始 BlockMask。相反地,它會傳回一個新的 BlockMask 實例,其中個別的 tensor 屬性可能會或可能不會移動到指定的裝置,具體取決於它們目前的裝置位置。

to_dense()[原始碼][原始碼]

傳回一個與 block mask 等效的密集區塊。

回傳類型

Tensor

to_string(grid_size=(20, 20), limit=4)[原始碼][原始碼]

傳回 block mask 的字串表示形式。 非常巧妙。

如果 grid_size 為 None,則印出一個未壓縮的版本。 警告,它可能會非常大!

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和高級開發人員的深入教學

檢視教學

資源

尋找開發資源並取得您問題的解答

檢視資源