torch.split¶
- torch.split(tensor, split_size_or_sections, dim=0)[原始碼][原始碼]¶
將張量分割成多個區塊。每個區塊都是原始張量的一個檢視 (view)。
如果
split_size_or_sections
是一個整數型別,則tensor
將被分割成大小相等的區塊(如果可能的話)。 如果張量沿著給定的維度dim
的大小不能被split_size
整除,則最後一個區塊會比較小。如果
split_size_or_sections
是一個列表,則tensor
將被分割成len(split_size_or_sections)
個區塊,其在dim
維度上的大小將根據split_size_or_sections
來決定。- 參數
- 回傳型別
範例
>>> a = torch.arange(10).reshape(5, 2) >>> a tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) >>> torch.split(a, 2) (tensor([[0, 1], [2, 3]]), tensor([[4, 5], [6, 7]]), tensor([[8, 9]])) >>> torch.split(a, [1, 4]) (tensor([[0, 1]]), tensor([[2, 3], [4, 5], [6, 7], [8, 9]]))