torch.Tensor.masked_scatter_¶
- Tensor.masked_scatter_(mask, source)¶
將
source
中的元素複製到self
張量的位置,條件是mask
為 True。從source
複製元素到self
,從source
的位置 0 開始,並按照順序,對每個mask
為 True 的情況,逐一進行複製。mask
的形狀必須可以與底層張量的形狀進行廣播 (broadcastable)。source
應該至少有與mask
中 1 的數量一樣多的元素。- 參數
mask (BoolTensor) – 布林遮罩
source (Tensor) – 要複製來源的張量
注意
mask
操作於self
張量,而不是給定的source
張量。範例
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter_(mask, source) tensor([[0, 0, 0, 0, 1], [2, 3, 0, 4, 5]])