CTCLoss¶
- class torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)[source][source]¶
連接主義時間分類 (Connectionist Temporal Classification) 損失函數。
計算連續(未分割)時間序列和目標序列之間的損失。CTCLoss 對輸入與目標的所有可能對齊方式的機率求和,產生一個可針對每個輸入節點微分的損失值。假設輸入與目標的對齊方式是「多對一」,這限制了目標序列的長度,使其必須 輸入長度。
- 參數
- 形狀
Log_probs: 大小為 或 的 Tensor,其中 , , 且 。 輸出(例如使用
torch.nn.functional.log_softmax()
獲得)的對數機率。目標:大小為 或 的張量,其中 且 。 它代表目標序列。 目標序列中的每個元素都是一個類別索引。 並且目標索引不能為空(預設值=0)。 在 形式中,目標被填充到最長序列的長度,並堆疊。 在 形式中,目標被假定為未填充並在 1 個維度內連接。
Input_lengths:大小為 或 的元組或張量,其中 。 它表示輸入的長度(每個長度都必須 )。 並且為每個序列指定長度,以實現遮罩,假設序列被填充到相等的長度。
Target_lengths:大小為 或 的 Tuple 或 tensor,其中 。它表示目標的長度。 為每個序列指定長度,以在假設序列填充到相等長度的情況下實現遮罩。如果目標形狀為 ,則 target_lengths 實際上是每個目標序列的停止索引 ,因此對於批次中的每個目標,
target_n = targets[n,0:s_n]
。長度必須各自 。如果目標以一維 tensor 形式給出,該 tensor 是個別目標的串聯,則 target_lengths 的總和必須等於 tensor 的總長度。輸出:如果
reduction
為'mean'
(預設) 或'sum'
,則為純量。如果reduction
為'none'
,則如果輸入為批次處理,則為 ,如果輸入未批次處理,則為 ,其中 。
範例
>>> # Target are to be padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> S = 30 # Target sequence length of longest target in batch (padding length) >>> S_min = 10 # Minimum target length, for demonstration purposes >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) >>> >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded and unbatched (effectively N=1) >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> >>> # Initialize random batch of input vectors, for *size = (T,C) >>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_() >>> input_lengths = torch.tensor(T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long) >>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward()
- 參考文獻
A. Graves et al.: Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks: https://www.cs.toronto.edu/~graves/icml_2006.pdf
注意
為了使用 CuDNN,必須滿足以下條件:
targets
必須採用串聯格式,所有input_lengths
必須為 T。,target_lengths
,整數參數的 dtype 必須為torch.int32
。常規實作使用(在 PyTorch 中更常見的) torch.long dtype。
注意
在某些情況下,當使用帶有 CuDNN 的 CUDA 後端時,此運算符可能會選擇非決定性演算法來提高效能。如果這是不可取的,您可以嘗試通過設定
torch.backends.cudnn.deterministic = True
來使操作具有確定性(可能會犧牲效能)。請參閱有關 再現性 的注意事項以了解背景資訊。