torch.flatten¶
- torch.flatten(input, start_dim=0, end_dim=-1) Tensor ¶
透過將
input
重新塑形為一維張量來展平(Flatten)它。如果傳遞了start_dim
或end_dim
,則只會展平從start_dim
開始到end_dim
結束的維度。input
中元素的順序不會改變。與 NumPy 的 flatten 總是複製輸入資料不同,此函式可能會傳回原始物件、視圖或副本。如果沒有展平任何維度,則會傳回原始物件
input
。否則,如果輸入可以被視為展平後的形狀,則會傳回該視圖。最後,只有在輸入不能被視為展平後的形狀時,才會複製輸入的資料。有關何時傳回視圖的詳細資訊,請參閱torch.Tensor.view()
。注意
展平一個零維張量將會傳回一個一維視圖。
範例
>>> t = torch.tensor([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> torch.flatten(t) tensor([1, 2, 3, 4, 5, 6, 7, 8]) >>> torch.flatten(t, start_dim=1) tensor([[1, 2, 3, 4], [5, 6, 7, 8]])