快捷方式

torch.unflatten

torch.unflatten(input, dim, sizes) Tensor

將輸入張量的一個維度擴展為多個維度。

參見

torch.flatten() 這個函式的反向操作。它將多個維度合併為一個維度。

參數
  • input (Tensor) – 輸入張量。

  • dim (int) – 要展開的維度,指定為 input.shape 中的索引。

  • sizes (Tuple[int]) – 展開維度的新形狀。其中一個元素可以是 -1,在這種情況下,將推斷對應的輸出維度。否則,sizes 的乘積必須等於 input.shape[dim]

返回值

輸入張量的一個視圖 (View),具有指定的展開維度。

範例:
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape
torch.Size([5, 2, 2, 3, 1, 1, 3])

文件

訪問 PyTorch 的完整開發者文檔

查看文檔

教學

獲取針對初學者和高級開發者的深入教學

查看教學

資源

查找開發資源並獲得問題解答

查看資源