捷徑

TransformerDecoder

class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)[原始碼][原始碼]

TransformerDecoder 是 N 個解碼器層的堆疊。

注意

請參閱本教學課程,以深入討論 PyTorch 為您建構自己的 Transformer 層所提供的效能建構模組。

參數
  • decoder_layer (TransformerDecoderLayer) – TransformerDecoderLayer() 類別的實例 (必要)。

  • num_layers (int) – 解碼器中子解碼器層的數量 (必要)。

  • norm (Optional[Module]) – 層正規化元件 (可選)。

範例:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)[source][source]

依序將輸入(和遮罩)傳遞到解碼器層。

參數
  • tgt (Tensor) – 要傳給解碼器的序列 (必要)。

  • memory (Tensor) – 來自編碼器最後一層的序列 (必要)。

  • tgt_mask (Optional[Tensor]) – tgt 序列的遮罩 (可選)。

  • memory_mask (Optional[Tensor]) – memory 序列的遮罩 (可選)。

  • tgt_key_padding_mask (Optional[Tensor]) – 每個批次 tgt 鍵的遮罩 (可選)。

  • memory_key_padding_mask (Optional[Tensor]) – 每個批次 memory 鍵的遮罩 (可選)。

  • tgt_is_causal (Optional[bool]) – 如果指定,則將因果遮罩應用為 tgt mask。預設值:None;嘗試偵測因果遮罩。警告:tgt_is_causal 提供一個提示,表示 tgt_mask 是因果遮罩。提供不正確的提示可能會導致不正確的執行,包括向前和向後相容性。

  • memory_is_causal (bool) – 如果指定,則將因果遮罩應用為 memory mask。預設值:False。警告:memory_is_causal 提供一個提示,表示 memory_mask 是因果遮罩。提供不正確的提示可能會導致不正確的執行,包括向前和向後相容性。

回傳類型

Tensor

形狀

請參閱 Transformer 中的文件。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源