快捷鍵

KLDivLoss

class torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)[原始碼][原始碼]

Kullback-Leibler 散度損失。

對於相同形狀的張量 ypred, ytruey_{\text{pred}},\ y_{\text{true}},其中 ypredy_{\text{pred}}inputytruey_{\text{true}}target,我們將 逐點 KL 散度 定義為

L(ypred, ytrue)=ytruelogytrueypred=ytrue(logytruelogypred)L(y_{\text{pred}},\ y_{\text{true}}) = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}} = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}})

為了避免計算此數值時發生下溢問題,此損失函數預期 input 參數位於對數空間中。如果 log_target= True,則 target 參數也可以在對數空間中提供。

總而言之,此函數大致等同於計算

if not log_target: # default
    loss_pointwise = target * (target.log() - input)
else:
    loss_pointwise = target.exp() * (target - input)

然後根據 reduction 參數,對此結果進行縮減,如下所示:

if reduction == "mean":  # default
    loss = loss_pointwise.mean()
elif reduction == "batchmean":  # mathematically correct
    loss = loss_pointwise.sum() / input.size(0)
elif reduction == "sum":
    loss = loss_pointwise.sum()
else:  # reduction == "none"
    loss = loss_pointwise

注意:

如同 PyTorch 中所有其他的損失函數,此函數期望第一個參數 input 為模型的輸出 (例如,神經網路),而第二個參數 target 為資料集中觀察到的值。這與標準數學符號 KL(P  Q)KL(P\ ||\ Q) 不同,其中 PP 表示觀察值的分佈,而 QQ 表示模型。

警告

reduction= “mean” 不會回傳真正的 KL 散度值,請使用 reduction= “batchmean”,這與數學定義一致。

參數
  • size_average (bool, optional) – 已棄用 (請參閱 reduction)。 預設情況下,損失值會對批次中的每個損失元素取平均。 請注意,對於某些損失函數,每個樣本有多個元素。 如果欄位 size_average 設為 False,則會對每個小批次中的損失值進行加總。 當 reduceFalse 時,此設定會被忽略。 預設值:True

  • reduce (bool, optional) – 已棄用 (請參閱 reduction)。 預設情況下,損失值會根據 size_average 的值,對每個小批次的觀察值取平均或加總。 當 reduceFalse 時,會改為回傳每個批次元素的損失值,並忽略 size_average。 預設值:True

  • reduction (str, optional) – 指定要套用到輸出的縮減方式。 預設值:“mean”

  • log_target (bool, optional) – 指定 target 是否在對數空間中。 預設值:False

形狀
  • 輸入: ()(*), 其中 * 表示任意數量的維度。

  • 目標: ()(*), 與輸入相同的形狀。

  • 輸出: 預設為純量。 如果 reduction‘none’,則 ()(*), 與輸入相同的形狀。

範例:
>>> kl_loss = nn.KLDivLoss(reduction="batchmean")
>>> # input should be a distribution in the log space
>>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
>>> # Sample a batch of distributions. Usually this would come from the dataset
>>> target = F.softmax(torch.rand(3, 5), dim=1)
>>> output = kl_loss(input, target)
>>> kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
>>> log_target = F.log_softmax(torch.rand(3, 5), dim=1)
>>> output = kl_loss(input, log_target)

文件

取得 PyTorch 的完整開發人員文件

查看文件

教學

取得針對初學者和進階開發人員的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源