捷徑

CausalVariant

class torch.nn.attention.bias.CausalVariant(value)[原始碼][原始碼]

用於注意力機制中的因果變體的列舉。

定義了兩種因果偏差類型

UPPER_LEFT: 代表標準因果注意力的左上三角形偏差。 建構此偏差的等效 pytorch 程式碼是

torch.tril(torch.ones(size, dtype=torch.bool))

例如,使用 shape=(3,4),具體化的偏差張量將會是

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0]]

LOWER_RIGHT: 表示右下三角形偏差,包含的值與矩陣的右下角對齊。

建構此偏差的等效 pytorch 程式碼是

diagonal_offset = size[1] - size[0]
torch.tril(
    torch.ones(size, dtype=torch.bool),
    diagonal=diagonal_offset,
)

例如,使用 shape=(3,4),具體化的偏差張量將會是

[[1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

請注意,當查詢和索引鍵/值張量的序列長度相等時,這些變體彼此等效,因為三角形矩陣是正方形。

警告

此枚舉是原型,可能會變更。

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得您問題的解答

查看資源