捷徑

torch.topk

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)

沿著給定的維度,傳回給定 input tensor 的 k 個最大元素。

如果沒有給定 dim,則會選擇 input 的最後一個維度。

如果 largestFalse,則會回傳 k 個最小的元素。

會回傳一個具名元組 (values, indices),其中包含 input 張量在給定維度 dim 的每列中最大的 k 個元素的 valuesindices

如果布林選項 sortedTrue,則會確保回傳的 k 個元素本身已排序。

參數
  • input (Tensor) – 輸入張量。

  • k (int) – “top-k” 中的 k

  • dim (int, optional) – 要沿其排序的維度

  • largest (bool, optional) – 控制是否回傳最大或最小元素

  • sorted (bool, optional) – 控制是否以排序順序回傳元素

關鍵字引數

out (tuple, optional) – (Tensor, LongTensor) 的輸出元組,可以選擇性地提供以用作輸出緩衝區

範例

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您問題的解答

檢視資源