快速鍵

CTCLoss

class torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)[source][source]

連接主義時間分類 (Connectionist Temporal Classification) 損失函數。

計算連續(未分割)時間序列和目標序列之間的損失。CTCLoss 對輸入與目標的所有可能對齊方式的機率求和,產生一個可針對每個輸入節點微分的損失值。假設輸入與目標的對齊方式是「多對一」,這限制了目標序列的長度,使其必須 \leq 輸入長度。

參數
  • blank (int, optional) – 空白標籤。預設值 00

  • reduction (str, optional) – 指定應用於輸出的 reduction 方式: 'none' | 'mean' | 'sum''none':不執行 reduction, 'mean':輸出損失將除以目標長度,然後取批次的平均值, 'sum':輸出損失將被加總。預設值: 'mean'

  • zero_infinity (bool, optional) – 是否將無限大的損失和相關梯度歸零。預設值: False。當輸入太短而無法與目標對齊時,通常會發生無限大的損失。

形狀
  • Log_probs: 大小為 (T,N,C)(T, N, C)(T,C)(T, C) 的 Tensor,其中 T=輸入長度T = \text{輸入長度}, N=批次大小N = \text{批次大小}, 且 C=類別數量(包含空白)C = \text{類別數量(包含空白)}。 輸出(例如使用 torch.nn.functional.log_softmax() 獲得)的對數機率。

  • 目標:大小為 (N,S)(N, S)(sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 的張量,其中 N=batch sizeN = \text{batch size}S=max target length, if shape is (N,S)S = \text{max target length, if shape is } (N, S)。 它代表目標序列。 目標序列中的每個元素都是一個類別索引。 並且目標索引不能為空(預設值=0)。 在 (N,S)(N, S) 形式中,目標被填充到最長序列的長度,並堆疊。 在 (sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 形式中,目標被假定為未填充並在 1 個維度內連接。

  • Input_lengths:大小為 (N)(N)()() 的元組或張量,其中 N=batch sizeN = \text{batch size}。 它表示輸入的長度(每個長度都必須 T\leq T)。 並且為每個序列指定長度,以實現遮罩,假設序列被填充到相等的長度。

  • Target_lengths:大小為 (N)(N)()() 的 Tuple 或 tensor,其中 N=batch sizeN = \text{batch size}。它表示目標的長度。 為每個序列指定長度,以在假設序列填充到相等長度的情況下實現遮罩。如果目標形狀為 (N,S)(N,S),則 target_lengths 實際上是每個目標序列的停止索引 sns_n,因此對於批次中的每個目標, target_n = targets[n,0:s_n]。長度必須各自 S\leq S。如果目標以一維 tensor 形式給出,該 tensor 是個別目標的串聯,則 target_lengths 的總和必須等於 tensor 的總長度。

  • 輸出:如果 reduction'mean' (預設) 或 'sum',則為純量。如果 reduction'none',則如果輸入為批次處理,則為 (N)(N),如果輸入未批次處理,則為 ()(),其中 N=batch sizeN = \text{batch size}

範例

>>> # Target are to be padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>> S = 30      # Target sequence length of longest target in batch (padding length)
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded and unbatched (effectively N=1)
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>>
>>> # Initialize random batch of input vectors, for *size = (T,C)
>>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
>>> input_lengths = torch.tensor(T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
參考文獻

A. Graves et al.: Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks: https://www.cs.toronto.edu/~graves/icml_2006.pdf

注意

為了使用 CuDNN,必須滿足以下條件: targets 必須採用串聯格式,所有 input_lengths 必須為 Tblank=0blank=0, target_lengths 256\leq 256,整數參數的 dtype 必須為 torch.int32

常規實作使用(在 PyTorch 中更常見的) torch.long dtype。

注意

在某些情況下,當使用帶有 CuDNN 的 CUDA 後端時,此運算符可能會選擇非決定性演算法來提高效能。如果這是不可取的,您可以嘗試通過設定 torch.backends.cudnn.deterministic = True 來使操作具有確定性(可能會犧牲效能)。請參閱有關 再現性 的注意事項以了解背景資訊。

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources