快捷方式

torch.nn.functional.one_hot

torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor

接受形狀為 (*) 的 LongTensor,其中包含索引值,並傳回形狀為 (*, num_classes) 的張量,除了最後一個維度的索引與輸入張量的對應值相符的位置之外,其他地方都是零,在這種情況下,它將是 1。

另請參閱 維基百科上的 One-hot

參數
  • tensor (LongTensor) – 任何形狀的類別值。

  • num_classes (int) – 類別總數。如果設定為 -1,類別數量將被推斷為比輸入張量中最大的類別值大一。

回傳

LongTensor,其維度增加一維,且最後一維的索引位置由輸入指定,該位置的值為 1,其他位置的值為 0。

範例

>>> F.one_hot(torch.arange(0, 5) % 3)
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0]])
>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0]])
>>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
tensor([[[1, 0, 0],
         [0, 1, 0]],
        [[0, 0, 1],
         [1, 0, 0]],
        [[0, 1, 0],
         [0, 0, 1]]])

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得針對初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得您問題的解答

檢視資源