torch.nn.functional.embedding¶
- torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)[原始碼][原始碼]¶
產生一個簡單的查詢表,用於在固定字典和大小中查找 embeddings。
此模組通常用於使用索引檢索詞嵌入。該模組的輸入是索引列表和嵌入矩陣,輸出是相應的詞嵌入。
詳情請參閱
torch.nn.Embedding
。注意
請注意,對於
padding_idx
指定的行中,此函數關於weight
條目的解析梯度,預期會與數值梯度不同。注意
請注意,:class:`torch.nn.Embedding 與此函數的不同之處在於,它在建構時會將
padding_idx
指定的weight
行初始化為全零。- 參數
input (LongTensor) – 包含嵌入矩陣索引的 Tensor
weight (Tensor) – 嵌入矩陣,其行數等於最大可能索引 + 1,列數等於嵌入大小
padding_idx (int, optional) – 如果指定,
padding_idx
處的條目不會對梯度產生影響;因此,padding_idx
處的嵌入向量在訓練期間不會更新,即保持為固定的“pad”。max_norm (float, optional) – 如果給定,則每個範數大於
max_norm
的嵌入向量,將被重新正規化為具有範數max_norm
。注意:這將會就地修改weight
。norm_type (float, optional) – 用於計算
max_norm
選項的 p 範數的 p 值。預設值為2
。scale_grad_by_freq (bool, optional) – 如果給定,這將按 mini-batch 中單詞頻率的倒數來縮放梯度。預設值為
False
。sparse (bool, optional) – 如果為
True
,則關於weight
的梯度將是一個稀疏 tensor。 有關稀疏梯度的更多詳細資訊,請參閱torch.nn.Embedding
下的 Notes。
- 回傳類型
- 形狀
Input: 包含要提取索引的任意形狀的 LongTensor
Weight: 浮點數類型的嵌入矩陣,形狀為 (V, embedding_dim),其中 V = 最大索引 + 1,embedding_dim = 嵌入大小
Output: (*, embedding_dim),其中 * 是輸入形狀
範例
>>> # a batch of 2 samples of 4 indices each >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) >>> # an embedding matrix containing 10 tensors of size 3 >>> embedding_matrix = torch.rand(10, 3) >>> F.embedding(input, embedding_matrix) tensor([[[ 0.8490, 0.9625, 0.6753], [ 0.9666, 0.7761, 0.6108], [ 0.6246, 0.9751, 0.3618], [ 0.4161, 0.2419, 0.7383]], [[ 0.6246, 0.9751, 0.3618], [ 0.0237, 0.7794, 0.0528], [ 0.9666, 0.7761, 0.6108], [ 0.3385, 0.8612, 0.1867]]]) >>> # example with padding_idx >>> weights = torch.rand(10, 3) >>> weights[0, :].zero_() >>> embedding_matrix = weights >>> input = torch.tensor([[0, 2, 0, 5]]) >>> F.embedding(input, embedding_matrix, padding_idx=0) tensor([[[ 0.0000, 0.0000, 0.0000], [ 0.5609, 0.5384, 0.8720], [ 0.0000, 0.0000, 0.0000], [ 0.6262, 0.2438, 0.7471]]])