Wav2Vec2FABundle¶
- class torchaudio.pipelines.Wav2Vec2FABundle[原始碼]¶
資料類別,捆綁相關資訊以使用預訓練的
Wav2Vec2Model
進行強制對齊。此類別提供介面,用於實例化預訓練模型,以及檢索預訓練權重和與模型一起使用的其他資料所需的資訊。
Torchaudio 函式庫實例化此類別的物件,每個物件代表不同的預訓練模型。用戶端程式碼應透過這些實例存取預訓練模型。
請參閱下方的使用方法和可用值。
- 範例 - 特徵提取
>>> import torchaudio >>> >>> bundle = torchaudio.pipelines.MMS_FA >>> >>> # Build the model and load pretrained weight. >>> model = bundle.get_model() Downloading: 100%|███████████████████████████████| 1.18G/1.18G [00:05<00:00, 216MB/s] >>> >>> # Resample audio to the expected sampling rate >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate) >>> >>> # Estimate the probability of token distribution >>> emission, _ = model(waveform) >>> >>> # Generate frame-wise alignment >>> alignment, scores = torchaudio.functional.forced_align( >>> emission, targets, input_lengths, target_lengths, blank=0) >>>
- 使用
Wav2Vec2FABundle
的教學
屬性¶
sample_rate¶
方法¶
get_aligner¶
get_dict¶
- Wav2Vec2FABundle.get_dict(star: Optional[str] = '*', blank: str = '-') Dict[str, int] [原始碼]¶
取得從符記到索引的映射 (在發射特徵維度中)
- 參數:
- 傳回:
對於在 ASR 上微調的模型,傳回代表輸出類別標籤的字串元組。
- 傳回類型:
Tuple[str, …]
- 範例
>>> from torchaudio.pipelines import MMS_FA as bundle >>> bundle.get_dict() {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '*': 28} >>> bundle.get_dict(star=None) {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27}
get_labels¶
- Wav2Vec2FABundle.get_labels(star: Optional[str] = '*', blank: str = '-') Tuple[str, ...] [原始碼]¶
取得對應於發射特徵維度的標籤。
第一個是空白符記,且可自訂。
- 參數:
- 傳回:
對於在 ASR 上微調的模型,傳回代表輸出類別標籤的字串元組。
- 傳回類型:
Tuple[str, …]
- 範例
>>> from torchaudio.pipelines import MMS_FA as bundle >>> bundle.get_labels() ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '*') >>> bundle.get_labels(star=None) ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x')
get_model¶
- Wav2Vec2FABundle.get_model(with_star: bool = True, *, dl_kwargs=None) Module [原始碼]¶
建構模型並載入預訓練權重。
權重檔案從網際網路下載並使用
torch.hub.load_state_dict_from_url()
進行快取- 參數:
with_star (bool, 可選) – 如果啟用,則輸出層的最後一個維度會擴展一個,對應於 星號 符記。
dl_kwargs (關鍵字引數字典) – 傳遞至
torch.hub.load_state_dict_from_url()
。
- 傳回:
Wav2Vec2Model
的變體。注意
使用此方法建立的模型會傳回對數域中的機率 (即套用
torch.nn.functional.log_softmax()
),而其他 Wav2Vec2 模型則傳回 logits。