捷徑

ConvEmformer

class torchaudio.prototype.models.ConvEmformer(input_dim: int, num_heads: int, ffn_dim: int, num_layers: int, segment_length: int, kernel_size: int, dropout: float = 0.0, ffn_activation: str = 'relu', left_context_length: int = 0, right_context_length: int = 0, max_memory_size: int = 0, weight_init_scale_strategy: Optional[str] = 'depthwise', tanh_on_mem: bool = False, negative_inf: float = -100000000.0, conv_activation: str = 'silu')[原始碼]

實作了卷積增強串流轉換器架構,該架構在使用非因果卷積的串流轉換器轉導器語音辨識 [Shi et al., 2022] 中被介紹。

參數:
  • input_dim (int) – 輸入維度。

  • num_heads (int) – 每個 ConvEmformer 層中的注意力頭數。

  • ffn_dim (int) – 每個 ConvEmformer 層的前饋網路的隱藏層維度。

  • num_layers (int) – 要實例化的 ConvEmformer 層數。

  • segment_length (int) – 每個輸入片段的長度。

  • kernel_size (int) – 卷積模組中使用的核心大小。

  • dropout (float, 選用) – dropout 機率。(預設值:0.0)

  • ffn_activation (str, 選用) – 在前饋網路中使用的激活函數。必須是 (“relu”, “gelu”, “silu”) 之一。(預設值:“relu”)

  • left_context_length (int, 選用) – 左側上下文的長度。(預設值:0)

  • right_context_length (int, 選用) – 右側上下文的長度。(預設值:0)

  • max_memory_size (int, 選用) – 要使用的最大記憶體元素數量。(預設值:0)

  • weight_init_scale_strategy (strNone, 選用) – 每層權重初始化縮放策略。必須是 (“depthwise”, “constant”, None) 之一。(預設值:“depthwise”)

  • tanh_on_mem (bool, 選用) – 如果 True,則對記憶體元素應用 tanh。(預設值:False)

  • negative_inf (float, 選用) – 在注意力權重中用於負無限大的值。(預設值:-1e8)

  • conv_activation (str, 選用) – 在卷積模組中使用的激活函數。必須是 (“relu”, “gelu”, “silu”) 之一。(預設值:“silu”)

範例

>>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
>>> input = torch.rand(10, 200, 80)
>>> lengths = torch.randint(1, 200, (10,))
>>> output, lengths = conv_emformer(input, lengths)
>>> input = torch.rand(4, 20, 80)
>>> lengths = torch.ones(4) * 20
>>> output, lengths, states = conv_emformer.infer(input, lengths, None)

方法

forward

ConvEmformer.forward(input: Tensor, lengths: Tensor) Tuple[Tensor, Tensor]

用於訓練和非串流推論的 Forward 傳遞。

B:批次大小;T:批次中輸入影格的最大數量;D:每個影格的特徵維度。

參數:
  • input (torch.Tensor) – 使用右側上下文影格右側填充的話語影格,形狀為 (B, T + right_context_length, D)

  • lengths (torch.Tensor) – 形狀為 (B,),且第 i 個元素表示 input 中第 i 個批次元素的有效話語影格數。

傳回:

Tensor

輸出影格,形狀為 (B, T, D)

Tensor

輸出長度,形狀為 (B,),且第 i 個元素表示輸出影格中第 i 個批次元素的有效影格數。

傳回類型:

(Tensor, Tensor)

infer

ConvEmformer.infer(input: Tensor, lengths: Tensor, states: Optional[List[List[Tensor]]] = None) Tuple[Tensor, Tensor, List[List">[Tensor]]]

用於串流推論的 Forward 傳遞。

B:批次大小;D:每個影格的特徵維度。

參數:
  • input (torch.Tensor) – 使用右側上下文影格右側填充的話語影格,形狀為 (B, segment_length + right_context_length, D)

  • lengths (torch.Tensor) – 形狀為 (B,),且第 i 個元素表示 input 中第 i 個批次元素的有效影格數。

  • states (List[List[torch.Tensor]] 或 None, 選用) – 張量列表的列表,表示在先前調用 infer 中生成的內部狀態。(預設值:None)

傳回:

Tensor

輸出影格,形狀為 (B, segment_length, D)

Tensor

輸出長度,形狀為 (B,),且第 i 個元素表示輸出影格中第 i 個批次元素的有效影格數。

List[List[Tensor]]

輸出狀態;張量列表的列表,表示在目前調用 infer 中生成的內部狀態。

傳回類型:

(Tensor, Tensor, List[List[Tensor]])

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源