torch.nn.functional.one_hot¶
- torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor ¶
接受形狀為
(*)
的 LongTensor,其中包含索引值,並傳回形狀為(*, num_classes)
的張量,除了最後一個維度的索引與輸入張量的對應值相符的位置之外,其他地方都是零,在這種情況下,它將是 1。另請參閱 維基百科上的 One-hot 。
- 參數
tensor (LongTensor) – 任何形狀的類別值。
num_classes (int) – 類別總數。如果設定為 -1,類別數量將被推斷為比輸入張量中最大的類別值大一。
- 回傳
LongTensor,其維度增加一維,且最後一維的索引位置由輸入指定,該位置的值為 1,其他位置的值為 0。
範例
>>> F.one_hot(torch.arange(0, 5) % 3) tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]]) >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]) >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) tensor([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]])