快捷鍵

torch.nn.functional.scaled_dot_product_attention

torch.nn.functional.scaled_dot_product_attention()
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,

is_causal=False, scale=None, enable_gqa=False) -> Tensor

使用 query、key 和 value tensors 計算縮放點積注意力,如果傳遞了可選的注意力遮罩,則使用該遮罩,如果指定的機率大於 0.0,則應用 dropout。 可選的 scale 參數只能指定為關鍵字參數。

# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

警告

此功能為 beta 版,可能會發生變更。

警告

此函數總是根據指定的 dropout_p 參數套用 dropout。若要在評估期間停用 dropout,請務必在呼叫此函數的模組不在訓練模式時,傳遞 0.0 的值。

例如

class MyModel(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, ...):
        return F.scaled_dot_product_attention(...,
            dropout_p=(self.p if self.training else 0.0))

注意

目前支援三種縮放點積注意力 (scaled dot product attention) 的實作方式:

使用 CUDA 後端時,此函數可能會呼叫最佳化的核心 (kernel) 以提高效能。對於所有其他後端,將使用 PyTorch 實作。

所有實作預設皆為啟用。縮放點積注意力會嘗試根據輸入自動選擇最優化的實作。為了提供對所使用實作更細緻的控制,提供了以下函數來啟用和停用實作。建議使用 context manager:

每個融合核心 (fused kernel) 都有特定的輸入限制。如果使用者需要使用特定的融合實作,請使用 torch.nn.attention.sdpa_kernel() 停用 PyTorch C++ 實作。如果沒有可用的融合實作,將會發出警告,說明融合實作無法運行的原因。

由於融合浮點運算的本質,此函數的輸出可能會因選擇的後端核心而異。c++ 實作支援 torch.float64,並且可以在需要更高精度時使用。對於 math 後端,如果輸入為 torch.half 或 torch.bfloat16,則所有中間值都將保留在 torch.float 中。

有關更多信息,請參閱 數值準確度

Grouped Query Attention (GQA) 是一個實驗性功能。目前僅適用於 CUDA tensor 上的 Flash_attention 和 math kernel,並且不支援 Nested tensor。GQA 的約束條件:

  • number_of_heads_query % number_of_heads_key_value == 0 且

  • number_of_heads_key == number_of_heads_value

注意

在某些情況下,當在 CUDA 裝置上給定 tensors 並使用 CuDNN 時,此運算子可能會選擇非決定性的演算法來提高效能。如果這是不可取的,您可以嘗試通過設定 torch.backends.cudnn.deterministic = True 使運算具有確定性(可能會以效能為代價)。有關更多信息,請參閱可重複性

參數
  • query (Tensor) – Query tensor; shape (N,...,Hq,L,E)(N, ..., Hq, L, E)

  • key (Tensor) – Key tensor; shape (N,...,H,S,E)(N, ..., H, S, E)

  • value (Tensor) – Value tensor; shape (N,...,H,S,Ev)(N, ..., H, S, Ev)

  • attn_mask (optional Tensor) – Attention mask; shape must be broadcastable to the shape of attention weights, which is (N,...,L,S)(N,..., L, S)。支援兩種型別的 mask。布林 mask,其中 True 值表示該元素應該參與注意力。與 query、key、value 相同型別的浮點 mask,會加到注意力分數中。

  • dropout_p (float) – Dropout 機率;如果大於 0.0,則套用 dropout。

  • is_causal (bool) – 如果設為 True,則當遮罩是方陣時,注意力遮罩會是下三角矩陣。當遮罩是非方陣時,由於對齊,注意力遮罩的形式為左上角的因果偏誤 (請參閱 torch.nn.attention.bias.CausalBias)。如果同時設定 attn_mask 和 is_causal,則會擲回錯誤。

  • scale (optional python:float, keyword-only) – 在 softmax 之前套用的縮放因子。 如果為 None,則預設值設定為 1E\frac{1}{\sqrt{E}}

  • enable_gqa (bool) – 如果設為 True,則啟用分組查詢注意力 (Grouped Query Attention, GQA),預設為 False。

傳回

注意力輸出;形狀為 (N,...,Hq,L,Ev)(N, ..., Hq, L, Ev)

傳回類型

output (Tensor)

形狀圖例
  • N:Batch size...:Any number of other batch dimensions (optional)N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}

  • S:Source sequence lengthS: \text{Source sequence length}

  • L:Target sequence lengthL: \text{Target sequence length}

  • E:Embedding dimension of the query and keyE: \text{Embedding dimension of the query and key}

  • Ev:Embedding dimension of the valueEv: \text{Embedding dimension of the value}

  • Hq:Number of heads of queryHq: \text{Number of heads of query}

  • H:Number of heads of key and valueH: \text{Number of heads of key and value}

範例

>>> # Optionally use the context manager to ensure one of the fused kernels is run
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
>>>     F.scaled_dot_product_attention(query,key,value)
>>> # Sample for GQA for llama3
>>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with sdpa_kernel(backends=[SDPBackend.MATH]):
>>>     F.scaled_dot_product_attention(query,key,value,enable_gqa=True)

文件

取得 PyTorch 完整的開發者文件

檢視文件

教學

取得初學者和進階開發人員的深度教學

檢視教學

資源

尋找開發資源並取得您問題的解答

檢視資源