捷徑

torch.flatten

torch.flatten(input, start_dim=0, end_dim=-1) Tensor

透過將 input 重新塑形為一維張量來展平(Flatten)它。如果傳遞了 start_dimend_dim,則只會展平從 start_dim 開始到 end_dim 結束的維度。 input 中元素的順序不會改變。

與 NumPy 的 flatten 總是複製輸入資料不同,此函式可能會傳回原始物件、視圖或副本。如果沒有展平任何維度,則會傳回原始物件 input。否則,如果輸入可以被視為展平後的形狀,則會傳回該視圖。最後,只有在輸入不能被視為展平後的形狀時,才會複製輸入的資料。有關何時傳回視圖的詳細資訊,請參閱 torch.Tensor.view()

注意

展平一個零維張量將會傳回一個一維視圖。

參數
  • input (Tensor) – 輸入張量。

  • start_dim (int) – 要展平的第一個維度

  • end_dim (int) – 要展平的最後一個維度

範例

>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources