torch.diagonal_scatter¶
- torch.diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) Tensor ¶
將
src
張量的值嵌入到input
張量的對角線元素中,嵌入時以dim1
和dim2
為參考維度。此函數會回傳一個具有全新儲存空間的張量;它不會回傳一個檢視 (view)。
參數
offset
控制要考慮哪條對角線如果
offset
= 0,則為主要對角線。如果
offset
> 0,則在主要對角線上方。如果
offset
< 0,則在主要對角線下方。
- 參數
注意
src
必須具有適當的大小,才能嵌入到input
中。具體來說,它應該與torch.diagonal(input, offset, dim1, dim2)
具有相同的形狀。範例
>>> a = torch.zeros(3, 3) >>> a tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]) >>> torch.diagonal_scatter(a, torch.ones(3), 0) tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) >>> torch.diagonal_scatter(a, torch.ones(2), 1) tensor([[0., 1., 0.], [0., 0., 1.], [0., 0., 0.]])