捷徑

torch.quantile

torch.quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) Tensor

計算沿著維度 diminput 張量的每個列的 q 分位數。

為了計算分位數,我們會將 [0, 1] 範圍內的 q 映射到索引 [0, n] 的範圍,以在排序後的輸入中找到分位數的位置。如果分位數落在兩個數據點 a < b 之間,它們在排序後的順序中分別具有索引 ij,則根據給定的 interpolation 方法計算結果,如下所示:

  • linear: a + (b - a) * fraction,其中 fraction 是計算出的分位數索引的小數部分。

  • lower: a

  • higher: b

  • nearest: ab,以索引更接近計算出的分位數索引的那個為準(.5 的小數部分向下捨入)。

  • midpoint: (a + b) / 2

如果 q 是一個 1D 張量,則輸出的第一個維度表示分位數,其大小等於 q 的大小,其餘維度是降維後剩下的維度。

注意

預設情況下,dimNone,導致在計算之前將 input 張量展平。

參數
  • input (Tensor) – 輸入張量。

  • q (floatTensor) – 範圍 [0, 1] 中的純量或 1D 張量值。

  • dim (int) – 要縮減的維度。

  • keepdim (bool) – 輸出張量是否保留 dim

關鍵字引數
  • interpolation (str) – 當所需的分位數落在兩個數據點之間時使用的插值方法。可以是 linearlowerhighermidpointnearest。預設值為 linear

  • out (Tensor, optional) – 輸出張量。

範例

>>> a = torch.randn(2, 3)
>>> a
tensor([[ 0.0795, -1.2117,  0.9765],
        [ 1.1707,  0.6706,  0.4884]])
>>> q = torch.tensor([0.25, 0.5, 0.75])
>>> torch.quantile(a, q, dim=1, keepdim=True)
tensor([[[-0.5661],
        [ 0.5795]],

        [[ 0.0795],
        [ 0.6706]],

        [[ 0.5280],
        [ 0.9206]]])
>>> torch.quantile(a, q, dim=1, keepdim=True).shape
torch.Size([3, 2, 1])
>>> a = torch.arange(4.)
>>> a
tensor([0., 1., 2., 3.])
>>> torch.quantile(a, 0.6, interpolation='linear')
tensor(1.8000)
>>> torch.quantile(a, 0.6, interpolation='lower')
tensor(1.)
>>> torch.quantile(a, 0.6, interpolation='higher')
tensor(2.)
>>> torch.quantile(a, 0.6, interpolation='midpoint')
tensor(1.5000)
>>> torch.quantile(a, 0.6, interpolation='nearest')
tensor(2.)
>>> torch.quantile(a, 0.4, interpolation='nearest')
tensor(1.)

文件

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources