torch.roll¶
- torch.roll(input, shifts, dims=None) Tensor ¶
沿著給定的維度滾動(roll)張量
input
。 超出最後一個位置的元素會重新引入到第一個位置。 如果dims
為 None,則張量將在滾動之前被展平,然後恢復到原始形狀。- 參數
範例
>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) >>> x tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) >>> torch.roll(x, 1) tensor([[8, 1], [2, 3], [4, 5], [6, 7]]) >>> torch.roll(x, 1, 0) tensor([[7, 8], [1, 2], [3, 4], [5, 6]]) >>> torch.roll(x, -1, 0) tensor([[3, 4], [5, 6], [7, 8], [1, 2]]) >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) tensor([[6, 5], [8, 7], [2, 1], [4, 3]])