快捷方式

Unflatten

class torch.nn.Unflatten(dim, unflattened_size)[原始碼][原始碼]

將 tensor 的 dim 展開為所需的形狀。用於 Sequential

  • dim 指定要展開的輸入 tensor 的維度,當使用 TensorNamedTensor 時,它可以是 intstr

  • unflattened_size 是 tensor 展開維度的新形狀,對於 Tensor 輸入,它可以是 tuple 的 ints 或 list 的 ints 或 torch.Size;對於 NamedTensor 輸入,它是 NamedShape(name, size) tuples 的 tuple)。

形狀
  • 輸入: (,Sdim,)(*, S_{\text{dim}}, *), 其中 SdimS_{\text{dim}} 是維度 dim 的大小,而 * 表示任意數量的維度,包括沒有維度。

  • 輸出: (,U1,...,Un,)(*, U_1, ..., U_n, *), 其中 UU = unflattened_sizei=1nUi=Sdim\prod_{i=1}^n U_i = S_{\text{dim}}

參數
  • dim (Union[int, str]) – 要展開的維度

  • unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – 展開維度的新形狀

範例

>>> input = torch.randn(2, 50)
>>> # With tuple of ints
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得問題解答

檢視資源