捷徑

torch.hsplit

torch.hsplit(input, indices_or_sections) List of Tensors

根據 indices_or_sections,將一個或多個維度的張量 input 水平分割成多個張量。每個分割都是 input 的視圖。

如果 input 是一維的,這等同於呼叫 torch.tensor_split(input, indices_or_sections, dim=0) (分割維度為零),如果 input 有兩個或更多維度,則等同於呼叫 torch.tensor_split(input, indices_or_sections, dim=1) (分割維度為 1),除非 indices_or_sections 是一個整數,它必須能均勻分割分割維度,否則會拋出執行時期錯誤。

這個函數基於 NumPy 的 numpy.hsplit()

參數
範例:
>>> t = torch.arange(16.0).reshape(4,4)
>>> t
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
>>> torch.hsplit(t, 2)
(tensor([[ 0.,  1.],
         [ 4.,  5.],
         [ 8.,  9.],
         [12., 13.]]),
 tensor([[ 2.,  3.],
         [ 6.,  7.],
         [10., 11.],
         [14., 15.]]))
>>> torch.hsplit(t, [3, 6])
(tensor([[ 0.,  1.,  2.],
         [ 4.,  5.,  6.],
         [ 8.,  9., 10.],
         [12., 13., 14.]]),
 tensor([[ 3.],
         [ 7.],
         [11.],
         [15.]]),
 tensor([], size=(4, 0)))

文件

取得 PyTorch 完整的開發者文件

檢視文件

教學課程

取得針對初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源