快捷鍵

torch.nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[來源][來源]

封裝包含可變長度填充序列的 Tensor。

input 的大小可以是 T x B x * (如果 batch_firstFalse) 或是 B x T x * (如果 batch_firstTrue),其中 T 是最長序列的長度,B 是批次大小,而 * 是任意數量的維度(包含 0)。

對於未排序的序列,請使用 enforce_sorted = False。如果 enforce_sortedTrue,則序列應按長度降序排序,即 input[:,0] 應該是最長的序列,而 input[:,B-1] 應該是最短的序列。enforce_sorted = True 僅在匯出 ONNX 時是必要的。

它是 pad_packed_sequence() 的反向操作,因此可以使用 pad_packed_sequence() 來恢復打包在 PackedSequence 中的底層張量。

注意

此函數接受至少具有兩個維度的任何輸入。 您可以應用它來打包標籤,並將 RNN 的輸出與它們一起使用以直接計算損失。 可以通過訪問其 .data 屬性,從 PackedSequence 物件中檢索張量。

參數
  • input (Tensor) – 可變長度序列的已填充批次。

  • lengths (Tensorlist(int)) – 每個批次元素的序列長度列表(如果作為張量提供,則必須在 CPU 上)。

  • batch_first (bool, optional) – 如果 True,則預期輸入為 B x T x * 格式,否則為 T x B x * 格式。

  • enforce_sorted (bool, optional) – 如果 True,則預期輸入包含按長度降序排序的序列。如果 False,則輸入將被無條件排序。預設值:True

返回值

一個 PackedSequence 物件

返回類型

PackedSequence

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和高級開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得您的問題解答

查看資源