torch.take_along_dim¶
- torch.take_along_dim(input, indices, dim=None, *, out=None) Tensor ¶
沿著給定的
dim
,從indices
的一維索引中,選擇input
中的值。如果
dim
為 None,則輸入陣列將被視為已展平為 1 維。沿著維度回傳索引的函式,例如
torch.argmax()
和torch.argsort()
,被設計成可以與此函式搭配使用。請參閱下面的範例。注意
此函式與 NumPy 的 take_along_axis 類似。另請參閱
torch.gather()
。- 參數
- 關鍵字參數
out (Tensor, 可選) – 輸出張量。
範例
>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) >>> max_idx = torch.argmax(t) >>> torch.take_along_dim(t, max_idx) tensor([60]) >>> sorted_idx = torch.argsort(t, dim=1) >>> torch.take_along_dim(t, sorted_idx, dim=1) tensor([[10, 20, 30], [40, 50, 60]])