快速鍵

GaussianNLLLoss

class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[原始碼][原始碼]

高斯負對數概似損失。

目標被視為來自高斯分佈的樣本,其期望值和變異數由神經網路預測。對於以期望值張量 input 和正變異數張量 var 建模為具有高斯分佈的 target 張量,損失為

loss=12(log(max(var, eps))+(inputtarget)2max(var, eps))+const.\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}

其中 eps 用於穩定性。預設情況下,除非 fullTrue,否則會省略損失函數的常數項。如果 var 的大小與 input 不同(由於同方差假設),則它必須具有為 1 的最終維度,或者少一個維度(所有其他大小相同)才能正確廣播。

參數
  • full (bool, optional) – 在損失計算中包含常數項。預設值:False

  • eps (float, optional) – 用於鉗制 var 的值(請參閱下面的註釋),以實現穩定性。預設值:1e-6。

  • reduction (str, optional) – 指定要應用於輸出的縮減方式:'none' | 'mean' | 'sum''none':不應用縮減,'mean':輸出是所有批次成員損失的平均值,'sum':輸出是所有批次成員損失的總和。預設值:'mean'

形狀
  • 輸入:(N,)(N, *)()(*),其中 * 表示任意數量的額外維度

  • 目標:(N,)(N, *)()(*),與輸入相同的形狀,或與輸入相同的形狀但其中一個維度等於 1(以允許廣播)

  • Var:(N,)(N, *)()(*),與輸入形狀相同,或與輸入形狀相同但其中一個維度等於 1,或與輸入形狀相同但維度少一個(以允許廣播),或一個純量值

  • 輸出:如果 reduction'mean' (預設) 或 'sum',則為純量。如果 reduction'none',則為 (N,)(N, *),與輸入形狀相同

範例:
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
>>> output = loss(input, target, var)
>>> output.backward()
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
>>> output = loss(input, target, var)
>>> output.backward()

注意

var 的鉗制會被 autograd 忽略,因此梯度不受其影響。

參考文獻

Nix, D. A. and Weigend, A. S., “Estimating the mean and variance of the target probability distribution”, Proceedings of 1994 IEEE International Conference on Neural Networks (ICNN’94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138.

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學課程

取得初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源