GaussianNLLLoss¶
- class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[原始碼][原始碼]¶
高斯負對數概似損失。
目標被視為來自高斯分佈的樣本,其期望值和變異數由神經網路預測。對於以期望值張量
input
和正變異數張量var
建模為具有高斯分佈的target
張量,損失為其中
eps
用於穩定性。預設情況下,除非full
為True
,否則會省略損失函數的常數項。如果var
的大小與input
不同(由於同方差假設),則它必須具有為 1 的最終維度,或者少一個維度(所有其他大小相同)才能正確廣播。- 參數
- 形狀
輸入: 或 ,其中 表示任意數量的額外維度
目標: 或 ,與輸入相同的形狀,或與輸入相同的形狀但其中一個維度等於 1(以允許廣播)
Var: 或 ,與輸入形狀相同,或與輸入形狀相同但其中一個維度等於 1,或與輸入形狀相同但維度少一個(以允許廣播),或一個純量值
輸出:如果
reduction
為'mean'
(預設) 或'sum'
,則為純量。如果reduction
為'none'
,則為 ,與輸入形狀相同
- 範例:
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic >>> output = loss(input, target, var) >>> output.backward()
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic >>> output = loss(input, target, var) >>> output.backward()
注意
對
var
的鉗制會被 autograd 忽略,因此梯度不受其影響。- 參考文獻
Nix, D. A. and Weigend, A. S., “Estimating the mean and variance of the target probability distribution”, Proceedings of 1994 IEEE International Conference on Neural Networks (ICNN’94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138.