CausalVariant¶
- class torch.nn.attention.bias.CausalVariant(value)[原始碼][原始碼]¶
用於注意力機制中的因果變體的列舉。
定義了兩種因果偏差類型
UPPER_LEFT: 代表標準因果注意力的左上三角形偏差。 建構此偏差的等效 pytorch 程式碼是
torch.tril(torch.ones(size, dtype=torch.bool))
例如,使用 shape=(3,4),具體化的偏差張量將會是
[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]
LOWER_RIGHT: 表示右下三角形偏差,包含的值與矩陣的右下角對齊。
建構此偏差的等效 pytorch 程式碼是
diagonal_offset = size[1] - size[0] torch.tril( torch.ones(size, dtype=torch.bool), diagonal=diagonal_offset, )
例如,使用 shape=(3,4),具體化的偏差張量將會是
[[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
請注意,當查詢和索引鍵/值張量的序列長度相等時,這些變體彼此等效,因為三角形矩陣是正方形。
警告
此枚舉是原型,可能會變更。