捷徑

MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]

允許模型共同關注來自不同表示子空間的資訊。

注意

請參閱此教學,以深入討論 PyTorch 提供的用於建構您自己的 Transformer 層的高效能建構區塊。

論文中描述的方法:Attention Is All You Need

Multi-Head Attention 定義為

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O

其中 headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

nn.MultiheadAttention 將在可能的情況下使用 scaled_dot_product_attention() 的最佳化實作。

除了支援新的 scaled_dot_product_attention() 函數之外,為了加速推論,MHA 將在使用 Nested Tensors 支援的快速路徑推論,如果:

  • 正在計算自我注意力(即,querykeyvalue 是相同的張量)。

  • 輸入是以批次方式(3D)處理,且 batch_first==True

  • autograd 已停用(使用 torch.inference_modetorch.no_grad),或者沒有張量引數 requires_grad

  • 訓練已停用(使用 .eval()

  • add_bias_kvFalse

  • add_zero_attnFalse

  • kdimvdim 等於 embed_dim

  • 如果傳遞了 NestedTensor,則既未傳遞 key_padding_mask 也未傳遞 attn_mask

  • autocast 已停用

如果正在使用最佳化的推論快速路徑實作,則可以為 query/key/value 傳遞 NestedTensor,以比使用填充遮罩更有效率地表示填充。 在這種情況下,將會傳回 NestedTensor,並且可以預期額外的加速與輸入中為填充的部分成正比。

參數
  • embed_dim – 模型總維度。

  • num_heads – 並行注意力頭的數量。 請注意,embed_dim 將在 num_heads 之間分割(即,每個頭的維度將為 embed_dim // num_heads)。

  • dropoutattn_output_weights 上的 dropout 概率。 預設值:0.0(無 dropout)。

  • bias – 如果指定,則將偏差新增至輸入/輸出投影層。 預設值:True

  • add_bias_kv – 如果指定,則將偏差新增至 dim=0 的 key 和 value 序列。 預設值:False

  • add_zero_attn – 如果指定,則將新的零批次新增至 dim=1 的 key 和 value 序列。 預設值:False

  • kdim – key 的特徵總數。 預設值:None(使用 kdim=embed_dim)。

  • vdim – value 的特徵總數。 預設值:None(使用 vdim=embed_dim)。

  • batch_first – 如果為 True,則輸入和輸出張量將以 (batch, seq, feature) 的形式提供。 預設值:False(seq, batch, feature)。

範例

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]

使用 query、key 和 value 嵌入來計算 attention 輸出。

支援 padding、遮罩和 attention 權重的可選參數。

參數
  • query (Tensor) – Query 嵌入,形狀為 (L,Eq)(L, E_q) 用於非批次輸入,(L,N,Eq)(L, N, E_q)batch_first=False(N,L,Eq)(N, L, E_q)batch_first=True,其中 LL 是目標序列長度,NN 是批次大小,而 EqE_q 是 query 嵌入維度 embed_dim。 Queries 會與 key-value 對進行比較,以產生輸出。 有關更多詳細資訊,請參閱「Attention Is All You Need」。

  • key (Tensor) – Key embeddings,形狀為:未批次輸入時為 (S,Ek)(S, E_k),當 batch_first=False 時為 (S,N,Ek)(S, N, E_k),當 batch_first=True 時為 (N,S,Ek)(N, S, E_k),其中 SS 是來源序列長度,NN 是批次大小,而 EkE_k 是 key 的 embedding 維度 kdim。 詳情請參閱 “Attention Is All You Need”。

  • value (Tensor) – 值嵌入 (Value embeddings),形狀為:未批次輸入時為 (S,Ev)(S, E_v),當 batch_first=False 時為 (S,N,Ev)(S, N, E_v),當 batch_first=True 時為 (N,S,Ev)(N, S, E_v),其中 SS 為來源序列長度 (source sequence length),NN 為批次大小 (batch size),而 EvE_v 為值嵌入維度 (value embedding dimension) vdim。詳情請參閱 "Attention Is All You Need"。

  • key_padding_mask (Optional[Tensor]) – 如果有指定,則為一個形狀為 (N,S)(N, S) 的遮罩 (mask),指示在注意力機制中要忽略 key 中的哪些元素 (即將其視為 "padding")。 對於未批次的 query,形狀應為 (S)(S)。支援二元和浮點數遮罩。對於二元遮罩,True 值表示對應的 key 值將在注意力機制中被忽略。對於浮點數遮罩,它將直接加到對應的 key 值。

  • need_weights (bool) – 如果指定,除了 attn_outputs 之外,還會回傳 attn_output_weights。設定 need_weights=False 以使用最佳化的 scaled_dot_product_attention 並達到 MHA 的最佳效能。預設值:True

  • attn_mask (Optional[Tensor]) – 如果指定,這是一個 2D 或 3D 的遮罩,用於防止 attention 機制關注某些位置。其形狀必須為 (L,S)(L, S)(Nnum_heads,L,S)(N\cdot\text{num\_heads}, L, S),其中 NN 是批次大小 (batch size),LL 是目標序列長度 (target sequence length),而 SS 是來源序列長度 (source sequence length)。 2D 遮罩將會跨批次廣播 (broadcast),而 3D 遮罩允許批次中的每個項目使用不同的遮罩。支援二元 (binary) 和浮點數 (float) 遮罩。對於二元遮罩,True 值表示不允許關注對應位置。對於浮點數遮罩,遮罩值將被加到 attention 權重上。如果同時提供了 attn_mask 和 key_padding_mask,它們的類型應該匹配。

  • average_attn_weights (bool) – 如果為 true,表示返回的 attn_weights 應該在所有 heads (多頭注意力機制的頭) 上進行平均。否則,attn_weights 將會針對每個 head 分別提供。請注意,這個 flag 只有在 need_weights=True 時才有效。預設值: True (即跨 heads 平均權重)

  • is_causal (bool) – 如果指定,則應用因果遮罩 (causal mask) 作為 attention 遮罩。 預設值: False。 警告: is_causal 提供了一個提示,表明 attn_mask 是因果遮罩。 提供不正確的提示可能會導致不正確的執行,包括向前和向後相容性問題。

回傳類型

Tuple[Tensor, Optional[Tensor]]

輸出
  • attn_output - Attention 的輸出,當輸入未批次化 (unbatched) 時,形狀為 (L,E)(L, E),當 batch_first=False 時,形狀為 (L,N,E)(L, N, E),當 batch_first=True 時,形狀為 (N,L,E)(N, L, E),其中 LL 是目標序列長度,NN 是批次大小,而 EE 是 embedding 維度 embed_dim

  • attn_output_weights - 僅當 need_weights=True 時才會返回。如果 average_attn_weights=True,則返回跨 head 平均的 attention weights,形狀為 (L,S)(L, S)(當輸入未批次化時)或 (N,L,S)(N, L, S),其中 NN 是批次大小,LL 是目標序列長度,SS 是來源序列長度。如果 average_attn_weights=False,則返回每個 head 的 attention weights,形狀為 (num_heads,L,S)(\text{num\_heads}, L, S)(當輸入未批次化時)或 (N,num_heads,L,S)(N, \text{num\_heads}, L, S)

注意

對於未批次化的輸入,batch_first 參數會被忽略。

merge_masks(attn_mask, key_padding_mask, query)[source][source]

確定 mask 類型並在必要時合併 mask。

如果只提供一個 mask,則將返回該 mask 和相應的 mask 類型。如果同時提供了兩個 mask,它們將都被擴展到形狀 (batch_size, num_heads, seq_len, seq_len),並使用邏輯 or 組合,並返回 mask 類型 2 :param attn_mask: 形狀為 (seq_len, seq_len) 的 attention mask,mask 類型 0 :param key_padding_mask: 形狀為 (batch_size, seq_len) 的 padding mask,mask 類型 1 :param query: 形狀為 (batch_size, seq_len, embed_dim) 的 query embeddings

返回

merged mask mask_type: 合併的 mask 類型 (0, 1, 或 2)

回傳類型

merged_mask

文件

取得 PyTorch 的完整開發者文件

查看文件

教學課程

取得針對初學者和進階開發者的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源