torch.index_select¶
- torch.index_select(input, dim, index, *, out=None) Tensor ¶
返回一個新張量,該張量使用 LongTensor 中的條目
index
,沿維度dim
索引input
張量。返回的張量具有與原始張量 (
input
) 相同的維度數量。第dim
個維度的大小與index
的長度相同;其他維度的大小與原始張量中的相同。注意
返回的張量不使用與原始張量相同的儲存空間。 如果
out
具有與預期不同的形狀,我們會靜默地將其更改為正確的形狀,必要時重新分配底層儲存空間。- 參數
- 關鍵字引數
out (Tensor, optional) – 輸出張量。
範例
>>> x = torch.randn(3, 4) >>> x tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-0.4664, 0.2647, -0.1228, -1.1068], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> indices = torch.tensor([0, 2]) >>> torch.index_select(x, 0, indices) tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> torch.index_select(x, 1, indices) tensor([[ 0.1427, -0.5414], [-0.4664, -0.1228], [-1.1734, 0.7230]])