快捷鍵

RNNT

class torchaudio.models.RNNT[原始碼]

遞歸神經網路轉換器 (RNN-T) 模型。

注意

要建置模型,請使用其中一個工廠函數。

另請參閱

torchaudio.pipelines.RNNTBundle: 具有預訓練模型的 ASR pipeline。

參數:

方法

forward

RNNT.forward(sources: Tensor, source_lengths: Tensor, targets: Tensor, target_lengths: Tensor, predictor_state: Optional[List[List[Tensor]]] = None) Tuple[Tensor, Tensor, Tensor, List[List[Tensor]]][原始碼]

用於訓練的前向傳遞。

B:批次大小;T:批次中最大來源序列長度;U:批次中最大目標序列長度;D:每個來源序列元素的功能維度。

參數:
  • sources (torch.Tensor) – 來源幀序列,以右側上下文進行右側填充,形狀為 (B, T, D)

  • source_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示 sources 中第 i 個批次元素的有效幀數。

  • targets (torch.Tensor) – 目標序列,形狀為 (B, U),每個元素映射到一個目標符號。

  • target_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示 targets 中第 i 個批次元素的有效幀數。

  • predictor_state (List[List[torch.Tensor]] or None, optional) – 列表的列表,表示在先前調用 forward 中生成的預測網路內部狀態。(預設值:None

返回:

torch.Tensor

聯合網路輸出,形狀為 (B, 最大輸出來源長度, 最大輸出目標長度, output_dim (目標符號數量))

torch.Tensor

輸出來源長度,形狀為 (B,),第 i 個元素表示聯合網路輸出中第 i 個批次元素沿維度 1 的有效元素數量。

torch.Tensor

輸出目標長度,形狀為 (B,),第 i 個元素表示聯合網路輸出中第 i 個批次元素沿維度 2 的有效元素數量。

List[List[torch.Tensor]]

輸出狀態;列表的列表,表示在目前調用 forward 中生成的預測網路內部狀態。

返回類型:

(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

transcribe_streaming

RNNT.transcribe_streaming(sources: Tensor, source_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]][原始碼]

以串流模式將轉錄網路應用於來源。

B:批次大小;T:批次中最大來源序列片段長度;D:每個來源序列幀的功能維度。

參數:
  • sources (torch.Tensor) – 來源幀序列片段,以右側上下文進行右側填充,形狀為 (B, T + 右側上下文長度, D)

  • source_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示 sources 中第 i 個批次元素的有效幀數。

  • state (List[List[torch.Tensor]] or None) – 列表的列表,表示在先前調用 transcribe_streaming 中生成的轉錄網路內部狀態。

返回:

torch.Tensor

輸出幀序列,形狀為 (B, T // time_reduction_stride, output_dim)

torch.Tensor

輸出長度,形狀為 (B,),第 i 個元素表示輸出中第 i 個批次元素的有效元素數量。

List[List[torch.Tensor]]

輸出狀態;列表的列表,表示在目前調用 transcribe_streaming 中生成的轉錄網路內部狀態。

返回類型:

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

transcribe

RNNT.transcribe(sources: torch.Tensor, source_lengths: torch.Tensor) Tuple[Tensor, Tensor][原始碼]

以非串流模式將轉錄網路應用於來源。

B:批次大小;T:批次中最大來源序列長度;D:每個來源序列幀的功能維度。

參數:
  • sources (torch.Tensor) – 來源幀序列,以右側上下文進行右側填充,形狀為 (B, T + 右側上下文長度, D)

  • source_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示 sources 中第 i 個批次元素的有效幀數。

返回:

torch.Tensor

輸出幀序列,形狀為 (B, T // time_reduction_stride, output_dim)

torch.Tensor

輸出長度,形狀為 (B,),第 i 個元素表示輸出幀序列中第 i 個批次元素的有效元素數量。

返回類型:

(torch.Tensor, torch.Tensor)

predict

RNNT.predict(targets: torch.Tensor, target_lengths: torch.Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]][原始碼]

將預測網路應用於目標。

B:批次大小;U:批次中最大目標序列長度;D:每個目標序列幀的功能維度。

參數:
  • targets (torch.Tensor) – 目標序列,形狀為 (B, U),每個元素映射到一個目標符號,即在範圍 [0, num_symbols) 內。

  • target_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示 targets 中第 i 個批次元素的有效幀數。

  • state (List[List[torch.Tensor]] or None) – 列表的列表,表示在先前調用 predict 中生成的內部狀態。

返回:

torch.Tensor

輸出幀序列,形狀為 (B, U, output_dim)

torch.Tensor

輸出長度,形狀為 (B,),第 i 個元素表示輸出中第 i 個批次元素的有效元素數量。

List[List[torch.Tensor]]

輸出狀態;列表的列表,表示在目前調用 predict 中生成的內部狀態。

返回類型:

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

join

RNNT.join(source_encodings: torch.Tensor, source_lengths: torch.Tensor, target_encodings: torch.Tensor, target_lengths: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor][原始碼]

將聯合網路應用於來源和目標編碼。

B:批次大小;T:批次中最大來源序列長度;U:批次中最大目標序列長度;D:每個來源和目標序列編碼的維度。

參數:
  • source_encodings (torch.Tensor) – 來源編碼序列,形狀為 (B, T, D)

  • source_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示 source_encodings 中第 i 個批次元素的有效序列長度。

  • target_encodings (torch.Tensor) – 目標編碼序列,形狀為 (B, U, D)

  • target_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示 target_encodings 中第 i 個批次元素的有效序列長度。

返回:

torch.Tensor

聯合網路輸出,形狀為 (B, T, U, output_dim)

torch.Tensor

輸出來源長度,形狀為 (B,),第 i 個元素表示聯合網路輸出中第 i 個批次元素沿維度 1 的有效元素數量。

torch.Tensor

輸出目標長度,形狀為 (B,),第 i 個元素表示聯合網路輸出中第 i 個批次元素沿維度 2 的有效元素數量。

返回類型:

(torch.Tensor, torch.Tensor, torch.Tensor)

工廠函數

emformer_rnnt_model

建置基於 Emformer 的 RNNT

emformer_rnnt_base

建置基於 Emformer 的 RNNT 的基本版本。

Prototype 工廠函數

conformer_rnnt_model

建置基於 Conformer 的遞歸神經網路轉換器 (RNN-T) 模型。

conformer_rnnt_base

建置 Conformer RNN-T 模型的基本版本。

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得您的問題解答

查看資源