MultiheadAttention¶
- class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]¶
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]¶
- 注意:
請參考
forward()
以獲得更多資訊
- 參數
query (Tensor) – 將查詢和一組鍵-值對應到一個輸出。詳情請參閱 "Attention Is All You Need"。
key (Tensor) – 將查詢和一組鍵-值對應到一個輸出。詳情請參閱 "Attention Is All You Need"。
value (Tensor) – 將查詢和一組鍵-值對應到一個輸出。詳情請參閱 "Attention Is All You Need"。
key_padding_mask (Optional[Tensor]) – 如果提供,則 attention 機制將忽略鍵中指定的填充元素。當給定一個二元遮罩且值為 True 時,attention 層上的相應值將被忽略。
need_weights (bool) – 輸出 attn_output_weights。
attn_mask (Optional[Tensor]) – 2D 或 3D 遮罩,用於防止對某些位置的 attention。2D 遮罩將廣播到所有批次,而 3D 遮罩允許為每個批次的條目指定不同的遮罩。
- 返回類型
- 形狀
輸入
query:,其中 L 為目標序列長度,N 為批次大小,E 為嵌入維度。如果
batch_first
為True
,則為 。key:,其中 S 為來源序列長度,N 為批次大小,E 為嵌入維度。如果
batch_first
為True
,則為 。value:,其中 S 為來源序列長度,N 為批次大小,E 為嵌入維度。如果
batch_first
為True
,則為 。key_padding_mask:,其中 N 為批次大小,S 為來源序列長度。如果提供的是 BoolTensor,則值為
True
的位置將會被忽略,而值為False
的位置則不會變更。attn_mask: 2D遮罩 ,其中 L 是目標序列長度,S 是來源序列長度。3D遮罩 ,其中 N 是批次大小,L 是目標序列長度,S 是來源序列長度。attn_mask 確保位置 i 允許關注未遮罩的位置。如果提供 BoolTensor,則具有
True
的位置不允許關注,而False
值將保持不變。如果提供 FloatTensor,它將被添加到注意力權重中。is_causal: 如果指定,則應用因果遮罩作為注意力遮罩。與提供 attn_mask 互斥。預設值:
False
。average_attn_weights: 如果為 true,表示返回的
attn_weights
應該在 head 之間進行平均。 否則,attn_weights
會為每個 head 單獨提供。 請注意,此標誌僅在need_weights=True
時才有效。 預設值:True(即,head 之間的平均權重)輸出
attn_output: ,其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。 如果
batch_first
是True
。attn_output_weights: 如果
average_attn_weights=True
,則返回跨 head 平均的注意力權重,形狀為 ,其中 N 是批次大小,L 是目標序列長度,S 是來源序列長度。 如果average_attn_weights=False
,則返回每個 head 的注意力權重,形狀為 。