torch.nn.functional.grid_sample¶
- torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)[原始碼][原始碼]¶
計算網格取樣 (grid sample)。
給定一個
input
和一個流場 (flow-field)grid
,使用input
的數值以及從grid
取得的像素位置來計算output
。目前僅支援空間 (4-D) 和體積 (5-D) 的
input
。在空間 (4-D) 的情況下,對於形狀為 的
input
和形狀為 的grid
,輸出將具有形狀 。對於每個輸出位置
output[n, :, h, w]
,大小為 2 的向量grid[n, h, w]
指定input
的像素位置x
和y
,這些位置被用來插值輸出值output[n, :, h, w]
。 在 5D 輸入的情況下,grid[n, d, h, w]
指定用於插值output[n, :, d, h, w]
的x
、y
、z
像素位置。mode
參數指定nearest
或bilinear
插值方法來取樣輸入像素。grid
指定了取樣像素的位置,這些位置已根據input
的空間維度進行了正規化。因此,它的大部分值應該在[-1, 1]
的範圍內。例如,值x = -1, y = -1
代表input
的左上角像素,而值x = 1, y = 1
代表input
的右下角像素。如果
grid
的值超出[-1, 1]
的範圍,則對應的輸出將按照padding_mode
的定義進行處理。選項如下:padding_mode="zeros"
:對於超出邊界的網格位置,使用0
。padding_mode="border"
:對於超出邊界的網格位置,使用邊界值。padding_mode="reflection"
:對於超出邊界的網格位置,使用邊界反射位置的值。對於遠離邊界的位置,它將持續反射直到進入邊界內,例如,(正規化的)像素位置x = -3.5
會以邊界-1
反射,變成x' = 1.5
,然後以邊界1
反射,變成x'' = -0.5
。
注意
此函數通常與
affine_grid()
結合使用,以構建 Spatial Transformer Networks。注意
當使用 CUDA 後端時,此操作可能會在其反向傳播中引起不確定性行為,而且不容易關閉。請參閱關於 重現性 的說明,以取得背景資訊。
注意
grid
中的 NaN 值將被解釋為-1
。- 參數
input (Tensor) – 形狀為 的輸入(4-D 情況)或 的輸入(5-D 情況)
grid (Tensor) – 形狀為 (4-D 的情況) 或 (5-D 的情況) 的 flow-field (流場)
mode (str) – 用於計算輸出值的插值模式
'bilinear'
|'nearest'
|'bicubic'
。預設值:'bilinear'
注意:mode='bicubic'
僅支援 4-D 輸入。當mode='bilinear'
且輸入為 5-D 時,內部使用的插值模式實際上是 trilinear。但是,當輸入為 4-D 時,插值模式確實是 bilinear。padding_mode (str) – 網格值之外的填充模式
'zeros'
|'border'
|'reflection'
。預設值:'zeros'
align_corners (bool, optional) – 在幾何上,我們將輸入的像素視為正方形而非點。如果設定為
True
,則極值 (-1
和1
) 被認為是指輸入角像素的中心點。如果設定為False
,則它們被認為是指輸入角像素的角點,從而使採樣更能適應解析度。此選項與interpolate()
中的align_corners
選項相同,因此此處使用的任何選項也應在那裡使用,以便在網格採樣之前調整輸入影像的大小。預設值:False
- Returns
輸出 Tensor
- Return type
output (Tensor)
Warning
當
align_corners = True
時,網格位置取決於像素大小相對於輸入影像大小,因此對於以不同解析度給定的相同輸入 (即,在經過向上取樣或向下取樣後),grid_sample()
採樣的位置將會不同。直到 1.2.0 版,預設行為都是align_corners = True
。從那時起,預設行為已更改為align_corners = False
,以便使其與interpolate()
的預設值保持一致。注意
mode='bicubic'
使用帶有 的 立方迴旋積演算法 (cubic convolution algorithm) 實作。常數 在不同的套件中可能不同。例如,PIL 和 OpenCV 分別使用 -0.5 和 -0.75。 此演算法可能會「超出 (overshoot)」它正在插值的數值範圍。 例如,當插值 [0, 255] 中的輸入時,它可能會產生負值或大於 255 的值。 使用torch.clamp()
鉗制 (clamp) 結果,以確保它們在有效範圍內。