快捷鍵

torch.nn.attention.bias.CausalBias

class torch.nn.attention.bias.CausalBias(variant, seq_len_q, seq_len_kv)[來源][來源]

代表因果注意力模式的偏差。 關於偏差結構的概述,請參閱 CausalVariant 列舉。

此類別用於定義因果(三角形)注意力偏差。 為了構建偏差,存在兩個工廠函數:causal_upper_left()causal_lower_right()

範例

from torch.nn.attention.bias import causal_lower_right

bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

# Create a lower-right causal bias
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(q, k, v, attn_bias)

警告

此類別為原型,可能會變更。

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源