RNNTBundle

class torchaudio.pipelines.RNNTBundle

Dataclass that bundles the components needed to perform automatic speech recognition (ASR, speech-to-text) inference with an RNN-T model.

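More specifically, the class provides methods that produce the featurization pipeline, a decoder wrapping the specified RNN-T model, and an output token post-processor; together these form a complete end-to-end ASR inference pipeline that produces a text sequence from a raw waveform.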
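
The order in which these three components are applied can be sketched as a single function. `transcribe` below is a hypothetical helper, not a torchaudio API; it mirrors the non-streaming example on this page and assumes, as that example does, that the decoder returns hypotheses best-first with each hypothesis's token IDs at index 0.

```python
from typing import Any, Callable, Sequence, Tuple


def transcribe(
    feature_extractor: Callable[[Any], Tuple[Any, Any]],
    decoder: Callable[[Any, Any, int], Sequence[Sequence[Any]]],
    token_processor: Callable[[Sequence[int]], str],
    waveform: Any,
    beam_width: int = 10,
) -> str:
    """Run featurization, beam search, and token post-processing in order."""
    # Waveform -> (features, length), e.g. mel-scale spectrogram frames.
    features, length = feature_extractor(waveform)
    # Beam search keeps up to `beam_width` hypotheses; assumed best-first.
    hypotheses = decoder(features, length, beam_width)
    # Index 0 of a hypothesis holds its predicted token IDs.
    best_tokens = hypotheses[0][0]
    return token_processor(best_tokens)
```

With a real bundle you would pass `bundle.get_feature_extractor()`, `bundle.get_decoder()` and `bundle.get_token_processor()`, and make the call inside `torch.no_grad()` as the example does.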

It supports both non-streaming (full-context) inference and streaming inference.

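Users should not instantiate objects of this class directly; instead, use the instances that exist within the module (each representing a pre-trained model), e.g. torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH.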

Example
>>> import torchaudio
>>> from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
>>> import torch
>>>
>>> # Non-streaming inference.
>>> # Build feature extractor, decoder with RNN-T model, and token processor.
>>> feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
100%|███████████████████████████████| 3.81k/3.81k [00:00<00:00, 4.22MB/s]
>>> decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
Downloading: "https://download.pytorch.org/torchaudio/models/emformer_rnnt_base_librispeech.pt"
100%|███████████████████████████████| 293M/293M [00:07<00:00, 42.1MB/s]
>>> token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()
100%|███████████████████████████████| 295k/295k [00:00<00:00, 25.4MB/s]
>>>
>>> # Instantiate LibriSpeech dataset; retrieve waveform for first sample.
>>> dataset = torchaudio.datasets.LIBRISPEECH("/home/librispeech", url="test-clean")
>>> waveform = next(iter(dataset))[0].squeeze()
>>>
>>> with torch.no_grad():
...     # Produce mel-scale spectrogram features.
...     features, length = feature_extractor(waveform)
...
...     # Generate top-10 hypotheses.
...     hypotheses = decoder(features, length, 10)
>>>
>>> # For top hypothesis, convert predicted tokens to text.
>>> text = token_processor(hypotheses[0][0])
>>> print(text)
he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to [...]
>>>
>>>
>>> # Streaming inference.
>>> hop_length = EMFORMER_RNNT_BASE_LIBRISPEECH.hop_length
>>> num_samples_segment = EMFORMER_RNNT_BASE_LIBRISPEECH.segment_length * hop_length
>>> num_samples_segment_right_context = (
...     num_samples_segment + EMFORMER_RNNT_BASE_LIBRISPEECH.right_context_length * hop_length
... )
>>>
>>> # Build streaming inference feature extractor.
>>> streaming_feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_streaming_feature_extractor()
>>>
>>> # Process same waveform as before, this time sequentially across overlapping segments
>>> # to simulate streaming inference. Note the usage of ``streaming_feature_extractor`` and ``decoder.infer``.
>>> state, hypothesis = None, None
>>> for idx in range(0, len(waveform), num_samples_segment):
...     segment = waveform[idx: idx + num_samples_segment_right_context]
...     segment = torch.nn.functional.pad(segment, (0, num_samples_segment_right_context - len(segment)))
...     with torch.no_grad():
...         features, length = streaming_feature_extractor(segment)
...         hypotheses, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
...     hypothesis = hypotheses[0]
...     transcript = token_processor(hypothesis[0])
...     if transcript:
...         print(transcript, end=" ", flush=True)
he hoped there would be stew for dinner turn ips and car rots and bru 'd oes and fat mut ton pieces to [...]
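
The streaming loop above slices a waveform that is already in memory. With live audio, samples arrive in chunks of arbitrary size, so they must be buffered until a full segment plus its right context is available. The pure-Python sketch below (`stream_windows` is a hypothetical helper, not a torchaudio API) yields the same zero-padded windows as that loop, one at a time; each window would then be converted to a tensor, e.g. `torch.tensor(window)`, before going to `streaming_feature_extractor`.

```python
from typing import Iterable, Iterator, List, Sequence


def stream_windows(
    chunks: Iterable[Sequence[float]],
    num_samples_segment: int,
    num_samples_segment_right_context: int,
) -> Iterator[List[float]]:
    """Yield overlapping windows from audio arriving in arbitrary-size chunks.

    Each window holds one segment plus its right-context lookahead; successive
    windows start `num_samples_segment` samples apart, like the loop above.
    """
    buffer: List[float] = []
    for chunk in chunks:
        buffer.extend(chunk)
        # Emit a window only once its full lookahead has arrived.
        while len(buffer) >= num_samples_segment_right_context:
            yield buffer[:num_samples_segment_right_context]
            del buffer[:num_samples_segment]
    # End of stream: flush what is left, zero-padding the final windows.
    while buffer:
        window = buffer[:num_samples_segment_right_context]
        yield window + [0.0] * (num_samples_segment_right_context - len(window))
        del buffer[:num_samples_segment]
```
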
Tutorials using RNNTBundle
Online ASR with Emformer RNN-T

Device ASR with Emformer RNN-T

Properties

hop_length

property RNNTBundle.hop_length: int

Number of samples between successive frames in the input expected by the model.

Type:

int
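
Dividing hop_length by sample_rate gives the time step between feature frames. A short sketch; the values passed below are illustrative assumptions, not read from a bundle (query `EMFORMER_RNNT_BASE_LIBRISPEECH.hop_length` and `.sample_rate` for the real ones).

```python
def frame_shift_ms(hop_length: int, sample_rate: int) -> float:
    """Time between the starts of successive feature frames, in milliseconds."""
    return 1000.0 * hop_length / sample_rate


# Illustrative (assumed) values.
print(frame_shift_ms(hop_length=160, sample_rate=16000))  # 10.0
```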

n_fft

property RNNTBundle.n_fft: int

Size of the FFT window to use.

Type:

int
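
n_fft fixes both how much audio each frame covers and how finely the spectrum is resolved before mel filtering. The sketch below computes both for illustrative, assumed values; it assumes the analysis window spans all n_fft samples, and uses the fact that a real-valued signal's FFT has n_fft // 2 + 1 non-redundant bins.

```python
from typing import Tuple


def fft_window_info(n_fft: int, sample_rate: int) -> Tuple[float, int]:
    """Return (window duration in ms, number of one-sided frequency bins)."""
    window_ms = 1000.0 * n_fft / sample_rate
    num_bins = n_fft // 2 + 1  # DC up to and including Nyquist
    return window_ms, num_bins


# Illustrative (assumed) values.
print(fft_window_info(n_fft=400, sample_rate=16000))  # (25.0, 201)
```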

n_mels

property RNNTBundle.n_mels: int

Number of mel spectrogram features to extract from the input waveform.

Type:

int
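
Each of the n_mels features is the output of one triangular filter whose centre lies on a grid spaced evenly in mel. A minimal sketch of those centre frequencies, assuming the HTK mel formula (the default `mel_scale="htk"` of `torchaudio.transforms.MelSpectrogram`) and a range from 0 Hz to the Nyquist frequency; whether the bundle's extractor uses exactly this configuration is an assumption.

```python
import math
from typing import List


def hz_to_mel(f: float) -> float:
    # HTK-style mel scale.
    return 2595.0 * math.log10(1.0 + f / 700.0)


def mel_to_hz(m: float) -> float:
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def mel_center_frequencies(n_mels: int, sample_rate: int, f_min: float = 0.0) -> List[float]:
    """Centre frequencies (Hz) of n_mels triangular filters spaced evenly in mel."""
    f_max = sample_rate / 2.0  # Nyquist frequency
    m_min, m_max = hz_to_mel(f_min), hz_to_mel(f_max)
    # n_mels filters need n_mels + 2 edge points; the inner ones are the centres.
    step = (m_max - m_min) / (n_mels + 1)
    return [mel_to_hz(m_min + step * (i + 1)) for i in range(n_mels)]


# Illustrative (assumed) values.
centres = mel_center_frequencies(n_mels=80, sample_rate=16000)
print(len(centres), round(centres[0]), round(centres[-1]))
```

The centres crowd together at low frequencies and spread out at high ones, which is the point of the mel scale.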

right_context_length

property RNNTBundle.right_context_length: int

Number of frames in the right-context block of the input expected by the model.

Type:

int
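
Right context is lookahead: a segment cannot be processed until right_context_length further frames of audio have arrived, which puts a floor under streaming latency (compute time comes on top). The sketch below derives the window size used by the streaming example and the lookahead it implies; the values are illustrative assumptions, so use the bundle's properties in practice.

```python
from typing import Tuple


def streaming_window(segment_length: int, right_context_length: int,
                     hop_length: int, sample_rate: int) -> Tuple[int, float]:
    """Return (samples per streaming window, right-context lookahead in ms)."""
    segment_samples = segment_length * hop_length
    right_context_samples = right_context_length * hop_length
    window_samples = segment_samples + right_context_samples
    lookahead_ms = 1000.0 * right_context_samples / sample_rate
    return window_samples, lookahead_ms


# Illustrative (assumed) values.
print(streaming_window(segment_length=16, right_context_length=4,
                       hop_length=160, sample_rate=16000))
# (3200, 40.0)
```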

sample_rate

property RNNTBundle.sample_rate: int

Sample rate (in cycles per second) of input waveforms.

Type:

int
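
The feature extractors are called with a waveform only, not its sample rate, so they cannot resample for you; audio at another rate would produce frames whose timing and mel bands do not match what the model was trained on. A minimal guard, with `check_sample_rate` as a hypothetical helper; `torchaudio.functional.resample(waveform, orig_freq, new_freq)` is the torchaudio function its error message points to.

```python
def check_sample_rate(waveform_rate: int, bundle_rate: int) -> None:
    """Raise early if the audio's rate differs from the model's expected rate."""
    if waveform_rate != bundle_rate:
        raise ValueError(
            f"waveform is {waveform_rate} Hz but the model expects {bundle_rate} Hz; "
            "resample first, e.g. torchaudio.functional.resample(waveform, "
            f"orig_freq={waveform_rate}, new_freq={bundle_rate})"
        )
```

Typical use: `check_sample_rate(sr, EMFORMER_RNNT_BASE_LIBRISPEECH.sample_rate)` right after loading audio.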

segment_length

property RNNTBundle.segment_length: int

Number of frames in a segment of the input expected by the model.

Type:

int
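
segment_length sets the stride of the streaming loop in the example: each iteration consumes segment_length * hop_length new samples. The sketch below counts how many `decoder.infer` calls that loop makes for a waveform of a given length, using ceiling division so it matches `len(range(0, num_samples, step))`; the argument values are illustrative assumptions.

```python
def num_streaming_steps(num_samples: int, segment_length: int, hop_length: int) -> int:
    """Number of iterations of the streaming loop for a waveform of num_samples."""
    step = segment_length * hop_length  # new samples consumed per iteration
    # Ceiling division: same count as len(range(0, num_samples, step)).
    return (num_samples + step - 1) // step


# Illustrative (assumed) values: one second of 16 kHz audio.
print(num_streaming_steps(16000, segment_length=16, hop_length=160))  # 7
```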

Methods

get_decoder

RNNTBundle.get_decoder() → RNNTBeamSearch

Constructs the RNN-T decoder.

Returns:

RNNTBeamSearch
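
Because the decoder returns several hypotheses, the runner-up transcripts are available as well. A minimal sketch, assuming each hypothesis follows the `torchaudio.models.Hypothesis` layout `(tokens, predictor_out, state, score)`; `n_best` is a hypothetical helper. It keeps the decoder's own ordering rather than re-sorting by the raw score, since the beam search may rank hypotheses with a length-normalized key.

```python
from typing import Any, Callable, List, Sequence, Tuple


def n_best(
    hypotheses: Sequence[Sequence[Any]],
    token_processor: Callable[[Sequence[int]], str],
    n: int = 3,
) -> List[Tuple[str, float]]:
    """Return (text, score) for the first n hypotheses, in decoder order."""
    results = []
    for hypo in hypotheses[:n]:
        tokens, score = hypo[0], hypo[3]  # assumed Hypothesis layout
        results.append((token_processor(tokens), float(score)))
    return results
```

With the real components: `n_best(decoder(features, length, 10), token_processor)`.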

get_feature_extractor

RNNTBundle.get_feature_extractor() → FeatureExtractor

Constructs the feature extractor for non-streaming (full-context) ASR.

Returns:

FeatureExtractor

get_streaming_feature_extractor

RNNTBundle.get_streaming_feature_extractor() → FeatureExtractor

Constructs the feature extractor for streaming (simultaneous) ASR.

Returns:

FeatureExtractor

get_token_processor

RNNTBundle.get_token_processor() → TokenProcessor

Constructs the token processor.

Returns:

TokenProcessor
