Tensor Views¶
PyTorch 允許張量成為現有張量的 View
。View 張量與其基礎張量共享相同的底層資料。支援 View
避免了顯式的資料複製,因此允許我們進行快速且記憶體效率高的重塑、切片和逐元素操作。
例如,若要取得現有張量 t
的 view,您可以呼叫 t.view(...)
。
>>> t = torch.rand(4, 4)
>>> b = t.view(2, 8)
>>> t.storage().data_ptr() == b.storage().data_ptr() # `t` and `b` share the same underlying data.
True
# Modifying view tensor changes base tensor as well.
>>> b[0][0] = 3.14
>>> t[0][0]
tensor(3.14)
由於 view 與其基礎張量共享底層資料,因此如果您編輯 view 中的資料,它也會反映在基礎張量中。
通常,PyTorch 的操作 (op) 會回傳一個新的 tensor 作為輸出,例如 add()
。但在 view 操作的情況下,為了避免不必要的資料複製,輸出是輸入 tensor 的 view。建立 view 時不會發生資料移動,view tensor 只是改變了它解釋相同資料的方式。取得 contiguous tensor 的 view 可能會產生一個 non-contiguous tensor。使用者應格外注意,因為 contiguity 可能會對效能產生隱含的影響。transpose()
是一個常見的例子。
>>> base = torch.tensor([[0, 1],[2, 3]])
>>> base.is_contiguous()
True
>>> t = base.transpose(0, 1) # `t` is a view of `base`. No data movement happened here.
# View tensors might be non-contiguous.
>>> t.is_contiguous()
False
# To get a contiguous tensor, call `.contiguous()` to enforce
# copying data when `t` is not contiguous.
>>> c = t.contiguous()
作為參考,以下是 PyTorch 中 view 操作的完整列表
基本的 slicing 和 indexing 操作,例如
tensor[0, 2:, 1:7:2]
會回傳 basetensor
的 view,請參閱下面的註記。view_as_real()
split_with_sizes()
indices()
(僅限於 sparse tensor)values()
(僅限於 sparse tensor)
註記
當透過 indexing 存取 tensor 的內容時,PyTorch 遵循 Numpy 的行為,即 basic indexing 回傳 view,而 advanced indexing 回傳副本。透過 basic 或 advanced indexing 進行賦值 (assignment) 都是 in-place 的。請參閱 Numpy indexing 文件 中更多的範例。
此外,還值得一提的是一些具有特殊行為的操作
reshape()
、reshape_as()
和flatten()
可以回傳 view 或新的 tensor,使用者程式碼不應依賴它是否為 view。contiguous()
如果輸入 tensor 已經是 contiguous 的,則回傳本身,否則它會複製資料並回傳一個新的 contiguous tensor。
有關 PyTorch 內部實作的更詳細演練,請參閱 ezyang 的關於 PyTorch Internals 的 blogpost。