torchtext.nn¶
MultiheadAttentionContainer¶
- class torchtext.nn.MultiheadAttentionContainer(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)[原始碼]¶
- __init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False) 無 [原始碼]¶
一個多頭注意力容器
- 參數:
nhead – 多頭注意力模型中的頭數
in_proj_container – 一個包含多頭輸入投影線性層的容器(又稱 nn.Linear)。
attention_layer – 自訂注意力層。從 MHA 容器發送到注意力層的輸入形狀為 (…, L, N * H, E / H)(查詢)和 (…, S, N * H, E / H)(鍵/值),而注意力層的輸出形狀預計為 (…, L, N * H, E / H)。如果使用者希望整體 MultiheadAttentionContainer 具有廣播功能,則 attention_layer 需要支援廣播。
out_proj – 多頭輸出投影層(又稱 nn.Linear)。
batch_first – 如果為
True
,則輸入和輸出張量將以 (…, N, L, E) 的形式提供。預設值:False
- 範例:
>>> import torch >>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct >>> embed_dim, num_heads, bsz = 10, 5, 64 >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim)) >>> MHA = MultiheadAttentionContainer(num_heads, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)) >>> query = torch.rand((21, bsz, embed_dim)) >>> key = value = torch.rand((16, bsz, embed_dim)) >>> attn_output, attn_weights = MHA(query, key, value) >>> print(attn_output.shape) >>> torch.Size([21, 64, 10])
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) Tuple[Tensor, Tensor] [原始碼]¶
- 參數:
query (Tensor) – 注意力函數的查詢。有關詳細資訊,請參閱「Attention Is All You Need」。
key (Tensor) – 注意力函數的鍵。有關詳細資訊,請參閱「Attention Is All You Need」。
value (Tensor) – 注意力函數的值。有關詳細資訊,請參閱「Attention Is All You Need」。
attn_mask (BoolTensor, 選用) – 防止注意力集中在某些位置的 3D 遮罩。
bias_k (Tensor, 選用) – 要在序列維度(dim=-3)處新增到鍵的一個以上鍵和值序列。這些用於增量解碼。使用者應提供
bias_v
。bias_v (Tensor, 選用) – 要在序列維度(dim=-3)處新增到值的一個以上鍵和值序列。這些用於增量解碼。使用者也應提供
bias_k
。
形狀
輸入
查詢:\((..., L, N, E)\)
鍵:\((..., S, N, E)\)
值:\((..., S, N, E)\)
attn_mask、bias_k 和 bias_v:與注意力層中對應參數的形狀相同。
輸出
attn_output:\((..., L, N, E)\)
attn_output_weights:\((N * H, L, S)\)
備註:查詢/鍵/值輸入可以選擇性地具有超過三個維度(用於廣播目的)。MultiheadAttentionContainer 模組將對最後三個維度進行操作。
其中 L 是目標長度,S 是序列長度,H 是注意力頭的數量,N 是批次大小,E 是嵌入維度。
InProjContainer¶
- class torchtext.nn.InProjContainer(query_proj, key_proj, value_proj)[原始碼]¶
- __init__(query_proj, key_proj, value_proj) 無 [原始碼]¶
一個用於在 MultiheadAttention 中投影查詢/鍵/值的 in-proj 容器。此模組發生在將投影的查詢/鍵/值重塑為多個頭部之前。請參閱 Attention Is All You Need 論文圖 2 中多頭注意力機制的線性層(底部)。另請查看 torchtext.nn.MultiheadAttentionContainer 中的使用範例。
- 參數:
query_proj – 查詢的投影層。典型的投影層是 torch.nn.Linear。
key_proj – 鍵的投影層。典型的投影層是 torch.nn.Linear。
value_proj – 值的投影層。典型的投影層是 torch.nn.Linear。
- forward(query: Tensor, key: Tensor, value: Tensor) Tuple[Tensor, Tensor, Tensor] [source]¶
使用 in-proj 層投影輸入序列。查詢/鍵/值分別傳遞到 query/key/value_proj 的正向函數。
- 參數:
query (Tensor) – 要投影的查詢。
key (Tensor) – 要投影的鍵。
value (Tensor) – 要投影的值。
- 範例:
>>> import torch >>> from torchtext.nn import InProjContainer >>> embed_dim, bsz = 10, 64 >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim)) >>> q = torch.rand((5, bsz, embed_dim)) >>> k = v = torch.rand((6, bsz, embed_dim)) >>> q, k, v = in_proj_container(q, k, v)
ScaledDotProduct¶
- class torchtext.nn.ScaledDotProduct(dropout=0.0, batch_first=False)[source]¶
- __init__(dropout=0.0, batch_first=False) None [source]¶
處理投影的查詢和鍵值對以應用縮放點積注意力。
- 參數:
dropout (float) – 丟棄注意力權重的機率。
batch_first – 如果為
True
,則輸入和輸出張量將以 (batch, seq, feature) 的形式提供。默認為False
- 範例:
>>> import torch, torchtext >>> SDP = torchtext.nn.ScaledDotProduct(dropout=0.1) >>> q = torch.randn(21, 256, 3) >>> k = v = torch.randn(21, 256, 3) >>> attn_output, attn_weights = SDP(q, k, v) >>> print(attn_output.shape, attn_weights.shape) torch.Size([21, 256, 3]) torch.Size([256, 21, 21])
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) Tuple[Tensor, Tensor] [source]¶
使用投影的鍵值對的縮放點積來更新投影的查詢。
- 參數:
query (Tensor) – 投影的查詢
key (Tensor) – 投影的鍵
value (Tensor) – 投影的值
attn_mask (BoolTensor, 選用) – 防止注意力集中在某些位置的 3D 遮罩。
attn_mask – 3D 遮罩,可防止注意力集中在某些位置。
bias_k (Tensor, 選用) – 要在序列維度(dim=-3)處新增到鍵的一個以上鍵和值序列。這些用於增量解碼。使用者應提供
bias_v
。bias_v (Tensor, 選用) – 要在序列維度(dim=-3)處新增到值的一個以上鍵和值序列。這些用於增量解碼。使用者也應提供
bias_k
。
- 形狀
查詢:\((..., L, N * H, E / H)\)
鍵:\((..., S, N * H, E / H)\)
值:\((..., S, N * H, E / H)\)
- attn_mask:\((N * H, L, S)\),不允許參與注意力計算的位置為
True
而
False
值將保持不變。
- attn_mask:\((N * H, L, S)\),不允許參與注意力計算的位置為
bias_k 和 bias_v:bias:\((1, N * H, E / H)\)
輸出:\((..., L, N * H, E / H)\),\((N * H, L, S)\)
- 注意:查詢/鍵/值的輸入可以選擇具有三個以上的維度(用於廣播目的)。
ScaledDotProduct 模組將作用於最後三個維度。
其中 L 是目標長度,S 是源長度,H 是注意力頭的數量,N 是批次大小,E 是嵌入維度。