Unfold¶
- class torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)[source][來源]¶
從批次輸入 Tensor 中提取滑動局部區塊。
考量一個批次的
input
張量,其形狀為 ,其中 是批次維度, 是通道維度,而 代表任意的空間維度。 這個操作會將input
空間維度內每個滑動的kernel_size
大小的區塊,扁平化成一個 3Doutput
張量的列(也就是最後一個維度),其形狀為 ,其中 是每個區塊內值的總數(一個區塊有 個空間位置,每個位置包含一個 -通道的向量),而 是這種區塊的總數。其中 是由
input
的空間維度構成的(如上面的 所示),而 則遍歷所有空間維度。因此,索引
output
的最後一個維度(欄維度)會得到特定區塊內的所有值。padding
、stride
和dilation
參數指定如何檢索滑動區塊。stride
控制滑動區塊的步幅。padding
控制在重塑形狀之前,在每個維度的padding
個點上,在兩側添加的隱式零填充量。dilation
控制 kernel 點之間的間距;也稱為 à trous 演算法。 雖然很難描述,但這個 連結 有一個很好的視覺化展示了dilation
的作用。
- 參數
如果
kernel_size
、dilation
、padding
或stride
是一個 int 或長度為 1 的 tuple,則它們的值將在所有空間維度上複製。在兩個輸入空間維度的情況下,此操作有時稱為
im2col
。
注意
Fold
通過對所有包含區塊中的所有值求和來計算結果大張量中的每個組合值。Unfold
通過從大張量複製來提取本地區塊中的值。 因此,如果區塊重疊,它們彼此不是反函數。通常,folding 和 unfolding 操作的相關方式如下。 考慮使用相同參數建立的
Fold
和Unfold
實例>>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) >>> fold = nn.Fold(output_size=..., **fold_params) >>> unfold = nn.Unfold(**fold_params)
然後,對於任何(支援的)
input
張量,以下等式成立fold(unfold(input)) == divisor * input
其中
divisor
是一個僅取決於input
的形狀和 dtype 的張量>>> input_ones = torch.ones(input.shape, dtype=input.dtype) >>> divisor = fold(unfold(input_ones))
當
divisor
張量不包含零元素時,fold
和unfold
操作是彼此的反函數(最多到一個常數除數)。警告
目前,僅支援 4-D 輸入張量(批次化圖像類張量)。
- 形狀
輸入:
輸出:如同上述說明的
範例
>>> unfold = nn.Unfold(kernel_size=(2, 3)) >>> input = torch.randn(2, 5, 3, 4) >>> output = unfold(input) >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels) >>> # 4 blocks (2x3 kernels) in total in the 3x4 input >>> output.size() torch.Size([2, 30, 4]) >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape) >>> inp = torch.randn(1, 3, 10, 12) >>> w = torch.randn(2, 3, 4, 5) >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5)) >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2) >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1)) >>> # or equivalently (and avoiding a copy), >>> # out = out_unf.view(1, 2, 7, 8) >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max() tensor(1.9073e-06)