快捷方式

torch.Tensor.masked_scatter_

Tensor.masked_scatter_(mask, source)

source 中的元素複製到 self 張量的位置,條件是 mask 為 True。從 source 複製元素到 self,從 source 的位置 0 開始,並按照順序,對每個 mask 為 True 的情況,逐一進行複製。mask 的形狀必須可以與底層張量的形狀進行廣播 (broadcastable)source 應該至少有與 mask 中 1 的數量一樣多的元素。

參數
  • mask (BoolTensor) – 布林遮罩

  • source (Tensor) – 要複製來源的張量

注意

mask 操作於 self 張量,而不是給定的 source 張量。

範例

>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources