快捷方式

MultiheadAttention

class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]
dequantize()[source][source]

將量化的 MHA 轉換回浮點數的工具。

這樣做的動機是,將權重從量化版本中使用的格式轉換回浮點數並非易事。

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]
注意:

請參考 forward() 以獲得更多資訊

參數
  • query (Tensor) – 將查詢和一組鍵-值對應到一個輸出。詳情請參閱 "Attention Is All You Need"。

  • key (Tensor) – 將查詢和一組鍵-值對應到一個輸出。詳情請參閱 "Attention Is All You Need"。

  • value (Tensor) – 將查詢和一組鍵-值對應到一個輸出。詳情請參閱 "Attention Is All You Need"。

  • key_padding_mask (Optional[Tensor]) – 如果提供,則 attention 機制將忽略鍵中指定的填充元素。當給定一個二元遮罩且值為 True 時,attention 層上的相應值將被忽略。

  • need_weights (bool) – 輸出 attn_output_weights。

  • attn_mask (Optional[Tensor]) – 2D 或 3D 遮罩,用於防止對某些位置的 attention。2D 遮罩將廣播到所有批次,而 3D 遮罩允許為每個批次的條目指定不同的遮罩。

返回類型

Tuple[Tensor, Optional[Tensor]]

形狀
  • 輸入

  • query:(L,N,E)(L, N, E),其中 L 為目標序列長度,N 為批次大小,E 為嵌入維度。如果 batch_firstTrue,則為 (N,L,E)(N, L, E)

  • key:(S,N,E)(S, N, E),其中 S 為來源序列長度,N 為批次大小,E 為嵌入維度。如果 batch_firstTrue,則為 (N,S,E)(N, S, E)

  • value:(S,N,E)(S, N, E),其中 S 為來源序列長度,N 為批次大小,E 為嵌入維度。如果 batch_firstTrue,則為 (N,S,E)(N, S, E)

  • key_padding_mask:(N,S)(N, S),其中 N 為批次大小,S 為來源序列長度。如果提供的是 BoolTensor,則值為 True 的位置將會被忽略,而值為 False 的位置則不會變更。

  • attn_mask: 2D遮罩 (L,S)(L, S),其中 L 是目標序列長度,S 是來源序列長度。3D遮罩 (Nnumheads,L,S)(N*num_heads, L, S),其中 N 是批次大小,L 是目標序列長度,S 是來源序列長度。attn_mask 確保位置 i 允許關注未遮罩的位置。如果提供 BoolTensor,則具有 True 的位置不允許關注,而 False 值將保持不變。如果提供 FloatTensor,它將被添加到注意力權重中。

  • is_causal: 如果指定,則應用因果遮罩作為注意力遮罩。與提供 attn_mask 互斥。預設值:False

  • average_attn_weights: 如果為 true,表示返回的 attn_weights 應該在 head 之間進行平均。 否則,attn_weights 會為每個 head 單獨提供。 請注意,此標誌僅在 need_weights=True 時才有效。 預設值:True(即,head 之間的平均權重)

  • 輸出

  • attn_output: (L,N,E)(L, N, E),其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。(N,L,E)(N, L, E) 如果 batch_firstTrue

  • attn_output_weights: 如果 average_attn_weights=True,則返回跨 head 平均的注意力權重,形狀為 (N,L,S)(N, L, S),其中 N 是批次大小,L 是目標序列長度,S 是來源序列長度。 如果 average_attn_weights=False,則返回每個 head 的注意力權重,形狀為 (N,numheads,L,S)(N, num_heads, L, S)

文件

存取 PyTorch 的完整開發者文件

查看文件

教學課程

取得初學者和進階開發者的深入教學課程

查看教學課程

資源

尋找開發資源並獲得您問題的解答

查看資源