torch.nn.functional.ctc_loss¶
- torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[來源][來源]¶
應用 Connectionist Temporal Classification 損失。
有關詳細資訊,請參閱
CTCLoss
。注意
在某些情況下,當在 CUDA 裝置上提供張量並使用 CuDNN 時,此運算子可能會選擇一個非確定性的演算法來提高效能。如果這不是您想要的,您可以嘗試通過設定
torch.backends.cudnn.deterministic = True
使操作具有確定性(可能會犧牲效能)。更多資訊請參閱 再現性 (Reproducibility)。注意
當在 CUDA 裝置上提供張量時,此操作可能會產生非確定性的梯度。更多資訊請參閱 再現性 (Reproducibility)。
- 參數
log_probs (Tensor) – 或 ,其中 C = 字母表中字元的數量(包括空白),T = 輸入長度,以及 N = 批次大小。 輸出的對數機率(例如,使用
torch.nn.functional.log_softmax()
獲得)。targets (Tensor) – 或 (sum(target_lengths))。 目標不能為空白。 在第二種形式中,目標假設為串聯的。
input_lengths (Tensor) – 或 。 輸入的長度(必須都 )
target_lengths (Tensor) – 或 。 目標的長度
blank (int, optional) – 空白標籤。 預設值 。
reduction (str, optional) – 指定應用於輸出的縮減方式:
'none'
|'mean'
|'sum'
。'none'
:不應用縮減,'mean'
:輸出損失將除以目標長度,然後取批次的平均值,'sum'
:輸出將被加總。 預設值:'mean'
zero_infinity (bool, optional) – 是否將無窮大的損失和相關的梯度歸零。 預設值:
False
當輸入太短而無法與目標對齊時,主要會發生無窮大的損失。
- 回傳型別
範例
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) >>> input_lengths = torch.full((16,), 50, dtype=torch.long) >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> loss.backward()