MultiheadAttention¶
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]¶
允許模型共同關注來自不同表示子空間的資訊。
注意
請參閱此教學,以深入討論 PyTorch 提供的用於建構您自己的 Transformer 層的高效能建構區塊。
論文中描述的方法:Attention Is All You Need。
Multi-Head Attention 定義為
其中 。
nn.MultiheadAttention
將在可能的情況下使用scaled_dot_product_attention()
的最佳化實作。除了支援新的
scaled_dot_product_attention()
函數之外,為了加速推論,MHA 將在使用 Nested Tensors 支援的快速路徑推論,如果:正在計算自我注意力(即,
query
、key
和value
是相同的張量)。輸入是以批次方式(3D)處理,且
batch_first==True
autograd 已停用(使用
torch.inference_mode
或torch.no_grad
),或者沒有張量引數requires_grad
訓練已停用(使用
.eval()
)add_bias_kv
為False
add_zero_attn
為False
kdim
和vdim
等於embed_dim
如果傳遞了 NestedTensor,則既未傳遞
key_padding_mask
也未傳遞attn_mask
autocast 已停用
如果正在使用最佳化的推論快速路徑實作,則可以為
query
/key
/value
傳遞 NestedTensor,以比使用填充遮罩更有效率地表示填充。 在這種情況下,將會傳回 NestedTensor,並且可以預期額外的加速與輸入中為填充的部分成正比。- 參數
embed_dim – 模型總維度。
num_heads – 並行注意力頭的數量。 請注意,
embed_dim
將在num_heads
之間分割(即,每個頭的維度將為embed_dim // num_heads
)。dropout –
attn_output_weights
上的 dropout 概率。 預設值:0.0
(無 dropout)。bias – 如果指定,則將偏差新增至輸入/輸出投影層。 預設值:
True
。add_bias_kv – 如果指定,則將偏差新增至 dim=0 的 key 和 value 序列。 預設值:
False
。add_zero_attn – 如果指定,則將新的零批次新增至 dim=1 的 key 和 value 序列。 預設值:
False
。kdim – key 的特徵總數。 預設值:
None
(使用kdim=embed_dim
)。vdim – value 的特徵總數。 預設值:
None
(使用vdim=embed_dim
)。batch_first – 如果為
True
,則輸入和輸出張量將以 (batch, seq, feature) 的形式提供。 預設值:False
(seq, batch, feature)。
範例
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]¶
使用 query、key 和 value 嵌入來計算 attention 輸出。
支援 padding、遮罩和 attention 權重的可選參數。
- 參數
query (Tensor) – Query 嵌入,形狀為 用於非批次輸入, 當
batch_first=False
或 當batch_first=True
,其中 是目標序列長度, 是批次大小,而 是 query 嵌入維度embed_dim
。 Queries 會與 key-value 對進行比較,以產生輸出。 有關更多詳細資訊,請參閱「Attention Is All You Need」。key (Tensor) – Key embeddings,形狀為:未批次輸入時為 ,當
batch_first=False
時為 ,當batch_first=True
時為 ,其中 是來源序列長度, 是批次大小,而 是 key 的 embedding 維度kdim
。 詳情請參閱 “Attention Is All You Need”。value (Tensor) – 值嵌入 (Value embeddings),形狀為:未批次輸入時為 ,當
batch_first=False
時為 ,當batch_first=True
時為 ,其中 為來源序列長度 (source sequence length), 為批次大小 (batch size),而 為值嵌入維度 (value embedding dimension)vdim
。詳情請參閱 "Attention Is All You Need"。key_padding_mask (Optional[Tensor]) – 如果有指定,則為一個形狀為 的遮罩 (mask),指示在注意力機制中要忽略
key
中的哪些元素 (即將其視為 "padding")。 對於未批次的 query,形狀應為 。支援二元和浮點數遮罩。對於二元遮罩,True
值表示對應的key
值將在注意力機制中被忽略。對於浮點數遮罩,它將直接加到對應的key
值。need_weights (bool) – 如果指定,除了
attn_outputs
之外,還會回傳attn_output_weights
。設定need_weights=False
以使用最佳化的scaled_dot_product_attention
並達到 MHA 的最佳效能。預設值:True
。attn_mask (Optional[Tensor]) – 如果指定,這是一個 2D 或 3D 的遮罩,用於防止 attention 機制關注某些位置。其形狀必須為 或 ,其中 是批次大小 (batch size), 是目標序列長度 (target sequence length),而 是來源序列長度 (source sequence length)。 2D 遮罩將會跨批次廣播 (broadcast),而 3D 遮罩允許批次中的每個項目使用不同的遮罩。支援二元 (binary) 和浮點數 (float) 遮罩。對於二元遮罩,
True
值表示不允許關注對應位置。對於浮點數遮罩,遮罩值將被加到 attention 權重上。如果同時提供了 attn_mask 和 key_padding_mask,它們的類型應該匹配。average_attn_weights (bool) – 如果為 true,表示返回的
attn_weights
應該在所有 heads (多頭注意力機制的頭) 上進行平均。否則,attn_weights
將會針對每個 head 分別提供。請注意,這個 flag 只有在need_weights=True
時才有效。預設值:True
(即跨 heads 平均權重)is_causal (bool) – 如果指定,則應用因果遮罩 (causal mask) 作為 attention 遮罩。 預設值:
False
。 警告:is_causal
提供了一個提示,表明attn_mask
是因果遮罩。 提供不正確的提示可能會導致不正確的執行,包括向前和向後相容性問題。
- 回傳類型
- 輸出
attn_output - Attention 的輸出,當輸入未批次化 (unbatched) 時,形狀為 ,當
batch_first=False
時,形狀為 ,當batch_first=True
時,形狀為 ,其中 是目標序列長度, 是批次大小,而 是 embedding 維度embed_dim
。attn_output_weights - 僅當
need_weights=True
時才會返回。如果average_attn_weights=True
,則返回跨 head 平均的 attention weights,形狀為 (當輸入未批次化時)或 ,其中 是批次大小, 是目標序列長度, 是來源序列長度。如果average_attn_weights=False
,則返回每個 head 的 attention weights,形狀為 (當輸入未批次化時)或 。
注意
對於未批次化的輸入,batch_first 參數會被忽略。
- merge_masks(attn_mask, key_padding_mask, query)[source][source]¶
確定 mask 類型並在必要時合併 mask。
如果只提供一個 mask,則將返回該 mask 和相應的 mask 類型。如果同時提供了兩個 mask,它們將都被擴展到形狀
(batch_size, num_heads, seq_len, seq_len)
,並使用邏輯or
組合,並返回 mask 類型 2 :param attn_mask: 形狀為(seq_len, seq_len)
的 attention mask,mask 類型 0 :param key_padding_mask: 形狀為(batch_size, seq_len)
的 padding mask,mask 類型 1 :param query: 形狀為(batch_size, seq_len, embed_dim)
的 query embeddings- 返回
merged mask mask_type: 合併的 mask 類型 (0, 1, 或 2)
- 回傳類型
merged_mask