RNNT¶
- class torchaudio.models.RNNT[原始碼]¶
遞歸神經網路轉換器 (RNN-T) 模型。
注意
要建置模型,請使用其中一個工廠函數。
另請參閱
torchaudio.pipelines.RNNTBundle
: 具有預訓練模型的 ASR pipeline。- 參數:
transcriber (torch.nn.Module) – 轉錄網路。
predictor (torch.nn.Module) – 預測網路。
joiner (torch.nn.Module) – 聯合網路。
方法¶
forward¶
- RNNT.forward(sources: Tensor, source_lengths: Tensor, targets: Tensor, target_lengths: Tensor, predictor_state: Optional[List[List[Tensor]]] = None) Tuple[Tensor, Tensor, Tensor, List[List[Tensor]]] [原始碼]¶
用於訓練的前向傳遞。
B:批次大小;T:批次中最大來源序列長度;U:批次中最大目標序列長度;D:每個來源序列元素的功能維度。
- 參數:
sources (torch.Tensor) – 來源幀序列,以右側上下文進行右側填充,形狀為 (B, T, D)。
source_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示
sources
中第 i 個批次元素的有效幀數。targets (torch.Tensor) – 目標序列,形狀為 (B, U),每個元素映射到一個目標符號。
target_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示
targets
中第 i 個批次元素的有效幀數。predictor_state (List[List[torch.Tensor]] or None, optional) – 列表的列表,表示在先前調用
forward
中生成的預測網路內部狀態。(預設值:None
)
- 返回:
- torch.Tensor
聯合網路輸出,形狀為 (B, 最大輸出來源長度, 最大輸出目標長度, output_dim (目標符號數量))。
- torch.Tensor
輸出來源長度,形狀為 (B,),第 i 個元素表示聯合網路輸出中第 i 個批次元素沿維度 1 的有效元素數量。
- torch.Tensor
輸出目標長度,形狀為 (B,),第 i 個元素表示聯合網路輸出中第 i 個批次元素沿維度 2 的有效元素數量。
- List[List[torch.Tensor]]
輸出狀態;列表的列表,表示在目前調用
forward
中生成的預測網路內部狀態。
- 返回類型:
(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
transcribe_streaming¶
- RNNT.transcribe_streaming(sources: Tensor, source_lengths: Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]] [原始碼]¶
以串流模式將轉錄網路應用於來源。
B:批次大小;T:批次中最大來源序列片段長度;D:每個來源序列幀的功能維度。
- 參數:
sources (torch.Tensor) – 來源幀序列片段,以右側上下文進行右側填充,形狀為 (B, T + 右側上下文長度, D)。
source_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示
sources
中第 i 個批次元素的有效幀數。state (List[List[torch.Tensor]] or None) – 列表的列表,表示在先前調用
transcribe_streaming
中生成的轉錄網路內部狀態。
- 返回:
- torch.Tensor
輸出幀序列,形狀為 (B, T // time_reduction_stride, output_dim)。
- torch.Tensor
輸出長度,形狀為 (B,),第 i 個元素表示輸出中第 i 個批次元素的有效元素數量。
- List[List[torch.Tensor]]
輸出狀態;列表的列表,表示在目前調用
transcribe_streaming
中生成的轉錄網路內部狀態。
- 返回類型:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
transcribe¶
- RNNT.transcribe(sources: torch.Tensor, source_lengths: torch.Tensor) Tuple[Tensor, Tensor] [原始碼]¶
以非串流模式將轉錄網路應用於來源。
B:批次大小;T:批次中最大來源序列長度;D:每個來源序列幀的功能維度。
- 參數:
sources (torch.Tensor) – 來源幀序列,以右側上下文進行右側填充,形狀為 (B, T + 右側上下文長度, D)。
source_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示
sources
中第 i 個批次元素的有效幀數。
- 返回:
- torch.Tensor
輸出幀序列,形狀為 (B, T // time_reduction_stride, output_dim)。
- torch.Tensor
輸出長度,形狀為 (B,),第 i 個元素表示輸出幀序列中第 i 個批次元素的有效元素數量。
- 返回類型:
predict¶
- RNNT.predict(targets: torch.Tensor, target_lengths: torch.Tensor, state: Optional[List[List[Tensor]]]) Tuple[Tensor, Tensor, List[List[Tensor]]] [原始碼]¶
將預測網路應用於目標。
B:批次大小;U:批次中最大目標序列長度;D:每個目標序列幀的功能維度。
- 參數:
targets (torch.Tensor) – 目標序列,形狀為 (B, U),每個元素映射到一個目標符號,即在範圍 [0, num_symbols) 內。
target_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示
targets
中第 i 個批次元素的有效幀數。state (List[List[torch.Tensor]] or None) – 列表的列表,表示在先前調用
predict
中生成的內部狀態。
- 返回:
- torch.Tensor
輸出幀序列,形狀為 (B, U, output_dim)。
- torch.Tensor
輸出長度,形狀為 (B,),第 i 個元素表示輸出中第 i 個批次元素的有效元素數量。
- List[List[torch.Tensor]]
輸出狀態;列表的列表,表示在目前調用
predict
中生成的內部狀態。
- 返回類型:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
join¶
- RNNT.join(source_encodings: torch.Tensor, source_lengths: torch.Tensor, target_encodings: torch.Tensor, target_lengths: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor] [原始碼]¶
將聯合網路應用於來源和目標編碼。
B:批次大小;T:批次中最大來源序列長度;U:批次中最大目標序列長度;D:每個來源和目標序列編碼的維度。
- 參數:
source_encodings (torch.Tensor) – 來源編碼序列,形狀為 (B, T, D)。
source_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示
source_encodings
中第 i 個批次元素的有效序列長度。target_encodings (torch.Tensor) – 目標編碼序列,形狀為 (B, U, D)。
target_lengths (torch.Tensor) – 形狀為 (B,),第 i 個元素表示
target_encodings
中第 i 個批次元素的有效序列長度。
- 返回:
- torch.Tensor
聯合網路輸出,形狀為 (B, T, U, output_dim)。
- torch.Tensor
輸出來源長度,形狀為 (B,),第 i 個元素表示聯合網路輸出中第 i 個批次元素沿維度 1 的有效元素數量。
- torch.Tensor
輸出目標長度,形狀為 (B,),第 i 個元素表示聯合網路輸出中第 i 個批次元素沿維度 2 的有效元素數量。
- 返回類型:
工廠函數¶
建置基於 Emformer 的 |
|
建置基於 Emformer 的 |
Prototype 工廠函數¶
建置基於 Conformer 的遞歸神經網路轉換器 (RNN-T) 模型。 |
|
建置 Conformer RNN-T 模型的基本版本。 |