torch.masked_select¶
- torch.masked_select(input, mask, *, out=None) Tensor ¶
根據布林遮罩
mask
(一個 BoolTensor) 索引input
張量,並傳回一個新的 1 維張量。mask
張量和input
張量的形狀不需要匹配,但它們必須是 可廣播的。注意
傳回的張量不使用與原始張量相同的儲存空間
- 參數
input (Tensor) – 輸入張量。
mask (BoolTensor) – 包含用於索引的二元遮罩的張量
- 關鍵字參數
out (Tensor, optional) – 輸出張量 (可選)。
範例
>>> x = torch.randn(3, 4) >>> x tensor([[ 0.3552, -2.3825, -0.8297, 0.3477], [-1.2035, 1.2252, 0.5002, 0.6248], [ 0.1307, -2.0608, 0.1244, 2.0139]]) >>> mask = x.ge(0.5) >>> mask tensor([[False, False, False, False], [False, True, True, True], [False, False, False, True]]) >>> torch.masked_select(x, mask) tensor([ 1.2252, 0.5002, 0.6248, 2.0139])