快捷方式

torch.take_along_dim

torch.take_along_dim(input, indices, dim=None, *, out=None) Tensor

沿著給定的 dim,從 indices 的一維索引中,選擇 input 中的值。

如果 dim 為 None,則輸入陣列將被視為已展平為 1 維。

沿著維度回傳索引的函式,例如 torch.argmax()torch.argsort(),被設計成可以與此函式搭配使用。請參閱下面的範例。

注意

此函式與 NumPy 的 take_along_axis 類似。另請參閱 torch.gather()

參數
  • input (Tensor) – 輸入張量。

  • indices (tensor) – input 的索引。必須為 long dtype。

  • dim (int, 可選) – 沿著選擇的維度。

關鍵字參數

out (Tensor, 可選) – 輸出張量。

範例

>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]])
>>> max_idx = torch.argmax(t)
>>> torch.take_along_dim(t, max_idx)
tensor([60])
>>> sorted_idx = torch.argsort(t, dim=1)
>>> torch.take_along_dim(t, sorted_idx, dim=1)
tensor([[10, 20, 30],
        [40, 50, 60]])

文件

取得 PyTorch 的完整開發人員文件

查看文件

教學

取得針對初學者和進階開發人員的深入教學

查看教學

資源

尋找開發資源並取得您的問題解答

查看資源