torch.repeat_interleave¶
- torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) Tensor ¶
重複張量的元素。
警告
這與
torch.Tensor.repeat()
不同,但類似於numpy.repeat
。- 參數
input (Tensor) – 輸入張量。
repeats (Tensor 或 int) – 每個元素的重複次數。 `repeats` 會被廣播以符合給定軸的形狀。
dim (int, 可選) – 沿哪個維度重複值。預設情況下,使用展平的輸入陣列,並返回一個展平的輸出陣列。
- 關鍵字參數
output_size (int, 可選) – 給定軸的總輸出大小(例如,重複次數的總和)。如果給定,它將避免計算張量輸出形狀所需的串流同步。
- 返回值
重複的張量,其形狀與輸入相同,除了沿給定軸。
- 返回類型
範例
>>> x = torch.tensor([1, 2, 3]) >>> x.repeat_interleave(2) tensor([1, 1, 2, 2, 3, 3]) >>> y = torch.tensor([[1, 2], [3, 4]]) >>> torch.repeat_interleave(y, 2) tensor([1, 1, 2, 2, 3, 3, 4, 4]) >>> torch.repeat_interleave(y, 3, dim=1) tensor([[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]]) >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) tensor([[1, 2], [3, 4], [3, 4]]) >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) tensor([[1, 2], [3, 4], [3, 4]])
如果 repeats 是 tensor([n1, n2, n3, …]),那麼輸出將是 tensor([0, 0, …, 1, 1, …, 2, 2, …, …]),其中 0 出現 n1 次,1 出現 n2 次,2 出現 n3 次,等等。
- torch.repeat_interleave(repeats, *) Tensor
重複 0 次 repeats[0] 次,重複 1 次 repeats[1] 次,重複 2 次 repeats[2] 次,等等。
範例
>>> torch.repeat_interleave(torch.tensor([1, 2, 3])) tensor([0, 1, 1, 2, 2, 2])