捷徑

torch.nn.utils.rnn.pad_packed_sequence

torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)[原始碼][原始碼]

填充已封裝的變長序列批次。

它是 pack_padded_sequence() 的反向操作。

回傳的 Tensor 的資料大小將為 T x B x * (如果 batch_firstFalse) 或 B x T x * (如果 batch_firstTrue),其中 T 是最長序列的長度,B 是批次大小。

範例

>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
               sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
        [3, 0, 0],
        [4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])

注意

total_lengthModule 中使用 pack sequence -> recurrent network -> unpack sequence 模式並用 DataParallel 包裝時很有用。詳情請參閱 此常見問題解答

參數
  • sequence (PackedSequence) – 要填充的批次

  • batch_first (bool, optional) – 如果 True,則輸出將採用 B x T x * 格式,否則採用 T x B x * 格式。

  • padding_value (float, optional) – 填充元素的值。

  • total_length (int, optional) – 如果不是 None,則輸出將被填充至長度 total_length。 如果 total_length 小於 sequence 中的最大序列長度,此方法將拋出 ValueError

回傳

包含已填充序列的 Tensor,以及包含批次中每個序列長度列表的 Tensor 的元組。 批次元素將按照原始順序重新排序,即批次傳遞給 pack_padded_sequencepack_sequence 時的順序。

回傳類型

Tuple[Tensor, Tensor]

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源