捷徑

AdaptiveLogSoftmaxWithLoss

class torch.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, device=None, dtype=None)[原始碼][原始碼]

有效的 softmax 近似。

Edouard Grave、Armand Joulin、Moustapha Cissé、David Grangier 和 Hervé Jégou 撰寫的適用於 GPU 的有效 softmax 近似中所述。

Adaptive softmax 是一種近似策略,用於訓練具有大型輸出空間的模型。當標籤分佈高度不平衡時,它最有效,例如在自然語言建模中,詞頻分佈近似於 Zipf 定律

Adaptive softmax 根據標籤的頻率將標籤劃分為幾個群集。這些群集可能包含不同數量的目標。此外,包含較不頻繁標籤的群集會為這些標籤分配較低維度的嵌入,從而加快計算速度。對於每個小批量,僅評估至少存在一個目標的群集。

這個概念是,經常訪問的群集(例如第一個群集,包含最頻繁的標籤),應該也易於計算,也就是說,包含少量分配的標籤。

我們強烈建議您查看原始論文以了解更多詳細資訊。

  • cutoffs 應該是一個按遞增順序排序的整數序列。它控制群集的數量以及目標到群集的劃分。例如,設定 cutoffs = [10, 100, 1000] 表示前 10 個目標將分配給 adaptive softmax 的「head」,目標 11, 12, …, 100 將分配給第一個群集,而目標 101, 102, …, 1000 將分配給第二個群集,而目標 1001, 1002, …, n_classes - 1 將分配給最後一個,也就是第三個群集。

  • div_value 用於計算每個額外群集的大小,計算公式為 in_featuresdiv_valueidx\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor,其中 idxidx 是群集索引(較不頻繁單詞的群集具有較大的索引,且索引從 11 開始)。

  • 如果設定為 True,head_bias 會將偏差項新增到 adaptive softmax 的「head」。詳情請參閱論文。在官方實現中設定為 False。

警告

作為輸入傳遞到此模組的標籤應根據其頻率排序。這表示最頻繁的標籤應由索引 0 表示,而最不頻繁的標籤應由索引 n_classes - 1 表示。

注意

此模組傳回具有 outputloss 欄位的 NamedTuple。有關詳細資訊,請參閱進一步的文件。

注意

若要計算所有類別的對數機率,可以使用 log_prob 方法。

參數
  • in_features (int) – 輸入張量中的特徵數量

  • n_classes (int) – 資料集中的類別數量

  • cutoffs (Sequence) – 用於將目標分配到其桶的 cutoff

  • div_value (float, optional) – 用作指數以計算群集大小的值。預設值:4.0

  • head_bias (bool, optional) – 如果 True,則將偏差項新增到 adaptive softmax 的「head」。預設值:False

傳回

  • output 是一個大小為 N 的張量,其中包含每個範例的計算目標對數機率

  • loss 是一個純量,代表計算的負對數概似損失

傳回類型

具有 outputloss 欄位的 NamedTuple

形狀
  • input: (N,in_features)(N, \texttt{in\_features})(in_features)(\texttt{in\_features})

  • 目標:(N)(N)()(),其中每個值都滿足 0<=target[i]<=n_classes0 <= \texttt{target[i]} <= \texttt{n\_classes}

  • 輸出 1:(N)(N)()()

  • 輸出 2:Scalar

log_prob(input)[source][source]

計算所有 n_classes\texttt{n\_classes} 的對數機率。

參數

input (Tensor) – 一個小批次的範例

傳回

每個類別 cc 在範圍 0<=c<=n_classes0 <= c <= \texttt{n\_classes} 的對數機率,其中 n_classes\texttt{n\_classes} 是傳遞給 AdaptiveLogSoftmaxWithLoss 建構子的參數。

傳回類型

Tensor

形狀
  • 輸入:(N,in_features)(N, \texttt{in\_features})

  • 輸出:(N,n_classes)(N, \texttt{n\_classes})

predict(input)[source][source]

回傳輸入小批次中每個範例具有最高機率的類別。

這相當於 self.log_prob(input).argmax(dim=1),但在某些情況下效率更高。

參數

input (Tensor) – 一個小批次的範例

傳回

每個範例具有最高機率的類別

傳回類型

輸出 (Tensor)

形狀
  • 輸入:(N,in_features)(N, \texttt{in\_features})

  • 輸出:(N)(N)

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得您的問題解答

檢視資源