torchaudio.functional.rnnt_loss¶
- torchaudio.functional.rnnt_loss(logits: Tensor, targets: Tensor, logit_lengths: Tensor, target_lengths: Tensor, blank: int = -1, clamp: float = -1, reduction: str = 'mean', fused_log_softmax: bool = True)[原始碼]¶
從使用遞迴神經網路的序列轉導 [Graves, 2012] 計算 RNN Transducer 損失。
RNN Transducer 損失透過定義所有長度的輸出序列的分佈,並聯合建模輸入-輸出和輸出-輸出依賴性,來擴展 CTC 損失。
- 參數:
logits (Tensor) – 維度為 (batch, max seq length, max target length + 1, class) 的 Tensor,包含來自 joiner 的輸出
targets (Tensor) – 維度為 (batch, max target length) 的 Tensor,包含以零填充的目標
logit_lengths (Tensor) – 維度為 (batch) 的 Tensor,包含來自編碼器的每個序列的長度
target_lengths (Tensor) – 維度為 (batch) 的 Tensor,包含每個序列的目標長度
blank (int, optional) – 空白標籤 (預設值:
-1
)clamp (float, optional) – 梯度鉗制 (預設值:
-1
)reduction (string, optional) – 指定要套用至輸出的縮減方式:
"none"
|"mean"
|"sum"
。 (預設值:"mean"
)fused_log_softmax (bool) – 如果在損失函數外部呼叫 log_softmax,則設定為 False (預設值:
True
)
- 返回:
套用縮減選項的損失。如果
reduction
為"none"
,則大小為 (batch),否則為純量。- 返回類型:
Tensor