torch.sort¶
- torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)¶
沿著給定的維度,以遞增的數值順序對
input
張量的元素進行排序。如果未提供
dim
,則選擇 input 的最後一個維度。如果
descending
為True
,則元素將以遞減的數值順序排序。如果
stable
為True
,則排序程序將變為穩定排序,從而保留等效元素的順序。回傳一個 (values, indices) 的 namedtuple,其中 values 是排序後的數值,而 indices 是原始 input 張量中元素的索引。
- 參數
- 關鍵字參數
out (tuple, 可選) – (Tensor, LongTensor) 的輸出元組,可以選擇性地提供以用作輸出緩衝區
範例
>>> x = torch.randn(3, 4) >>> sorted, indices = torch.sort(x) >>> sorted tensor([[-0.2162, 0.0608, 0.6719, 2.3332], [-0.5793, 0.0061, 0.6058, 0.9497], [-0.5071, 0.3343, 0.9553, 1.0960]]) >>> indices tensor([[ 1, 0, 2, 3], [ 3, 1, 0, 2], [ 0, 3, 1, 2]]) >>> sorted, indices = torch.sort(x, 0) >>> sorted tensor([[-0.5071, -0.2162, 0.6719, -0.5793], [ 0.0608, 0.0061, 0.9497, 0.3343], [ 0.6058, 0.9553, 1.0960, 2.3332]]) >>> indices tensor([[ 2, 0, 0, 1], [ 0, 1, 1, 2], [ 1, 2, 2, 0]]) >>> x = torch.tensor([0, 1] * 9) >>> x.sort() torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) >>> x.sort(stable=True) torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17]))