torch.argwhere¶
- torch.argwhere(input) Tensor ¶
傳回一個張量,其中包含
input
中所有非零元素的索引。結果中的每一列包含input
中一個非零元素的索引。結果依字典順序排序,且最後一個索引變化最快 (C 風格)。如果
input
具有 個維度,則產生的索引張量out
的大小為 ,其中 是input
張量中非零元素的總數。注意
此函數類似於 NumPy 的 argwhere。
當
input
位於 CUDA 上時,此函數會導致主機-裝置同步。- 參數
{input} –
範例
>>> t = torch.tensor([1, 0, 1]) >>> torch.argwhere(t) tensor([[0], [2]]) >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) >>> torch.argwhere(t) tensor([[0, 0], [0, 2], [1, 1], [1, 2]])