捷徑

torch.tensor_split

torch.tensor_split(input, indices_or_sections, dim=0) List of Tensors

將一個張量分割成多個子張量,所有這些子張量都是 input 的視圖 (view),沿著維度 dim 根據 indices_or_sections 指定的索引或區段數量進行分割。此函數基於 NumPy 的 numpy.array_split()

參數
  • input (Tensor) – 要分割的張量

  • indices_or_sections (Tensor, intlisttuple of ints) –

    如果 indices_or_sections 是一個整數 n 或一個值為 n 的零維長整數張量,則 input 沿著維度 dim 分割成 n 個區段。 如果 input 沿著維度 dim 可以被 n 整除,則每個區段的大小將相等,為 input.size(dim) / n。 如果 input 不能被 n 整除,則前 int(input.size(dim) % n) 個區段的大小將為 int(input.size(dim) / n) + 1,其餘區段的大小將為 int(input.size(dim) / n)

    如果 indices_or_sections 是整數列表或元組,或是一維長整數張量,則 input 會沿著維度 dim 在列表、元組或張量中的每個索引處進行分割。 例如,indices_or_sections=[2, 3]dim=0 將導致張量 input[:2]input[2:3]input[3:]

    如果 indices_or_sections 是一個張量,則它必須是 CPU 上的零維或一維長整數張量。

  • dim (int, optional) – 沿著其分割張量的維度。預設值:0

範例

>>> x = torch.arange(8)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

>>> x = torch.arange(7)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
>>> torch.tensor_split(x, (1, 6))
(tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

>>> x = torch.arange(14).reshape(2, 7)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13]])
>>> torch.tensor_split(x, 3, dim=1)
(tensor([[0, 1, 2],
        [7, 8, 9]]),
 tensor([[ 3,  4],
        [10, 11]]),
 tensor([[ 5,  6],
        [12, 13]]))
>>> torch.tensor_split(x, (1, 6), dim=1)
(tensor([[0],
        [7]]),
 tensor([[ 1,  2,  3,  4,  5],
        [ 8,  9, 10, 11, 12]]),
 tensor([[ 6],
        [13]]))

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources