Unflatten¶
- class torch.nn.Unflatten(dim, unflattened_size)[原始碼][原始碼]¶
將 tensor 的 dim 展開為所需的形狀。用於
Sequential
。dim
指定要展開的輸入 tensor 的維度,當使用 Tensor 或 NamedTensor 時,它可以是 int 或 str。unflattened_size
是 tensor 展開維度的新形狀,對於 Tensor 輸入,它可以是 tuple 的 ints 或 list 的 ints 或 torch.Size;對於 NamedTensor 輸入,它是 NamedShape((name, size) tuples 的 tuple)。
- 形狀
輸入: , 其中 是維度
dim
的大小,而 表示任意數量的維度,包括沒有維度。輸出: , 其中 =
unflattened_size
且 。
- 參數
unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – 展開維度的新形狀
範例
>>> input = torch.randn(2, 50) >>> # With tuple of ints >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) >>> input = torch.randn(2, 50, names=('N', 'features')) >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5])