快速連結

sigmoid_focal_loss

torchvision.ops.sigmoid_focal_loss(inputs: Tensor, targets: Tensor, alpha: float = 0.25, gamma: float = 2, reduction: str = 'none') Tensor[原始碼]

RetinaNet 中用於密集偵測的損失函數:https://arxiv.org/abs/1708.02002

參數:
  • inputs (Tensor) – 任意形狀的浮點張量。每個範例的預測值。

  • targets (Tensor) – 形狀與 inputs 相同的浮點張量。儲存 inputs 中每個元素的二元分類標籤(負類別為 0,正類別為 1)。

  • alpha (float) – 範圍 (0,1) 中的權重因子,用於平衡正負範例,若要忽略則設為 -1。預設值:0.25

  • gamma (float) – 調變因子 (1 - p_t) 的指數,用於平衡簡單與困難的範例。預設值:2

  • reduction (string) – 'none' | 'mean' | 'sum' 'none':不會對輸出套用縮減。'mean':輸出將會取平均值。'sum':輸出將會加總。預設值:'none'

回傳:

套用縮減選項的損失張量。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適用於初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源