快捷方式

torch.argwhere

torch.argwhere(input) Tensor

傳回一個張量,其中包含 input 中所有非零元素的索引。結果中的每一列包含 input 中一個非零元素的索引。結果依字典順序排序,且最後一個索引變化最快 (C 風格)。

如果 input 具有 nn 個維度,則產生的索引張量 out 的大小為 (z×n)(z \times n),其中 zzinput 張量中非零元素的總數。

注意

此函數類似於 NumPy 的 argwhere

input 位於 CUDA 上時,此函數會導致主機-裝置同步。

參數

{input}

範例

>>> t = torch.tensor([1, 0, 1])
>>> torch.argwhere(t)
tensor([[0],
        [2]])
>>> t = torch.tensor([[1, 0, 1], [0, 1, 1]])
>>> torch.argwhere(t)
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [1, 2]])

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您的問題解答

檢視資源