捷徑

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) Tensor

沿著 dim 指定的軸收集值。

對於 3 維張量,輸出由下式指定

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

inputindex 必須具有相同數量的維度。 也必須 index.size(d) <= input.size(d) 用於所有維度 d != dim. out 將具有與 index 相同的形狀。 請注意,inputindex 不會相互廣播。

參數
  • input ( Tensor) – 來源張量

  • dim ( int) – 要進行索引的軸

  • index (LongTensor) – 要收集的元素的索引

關鍵字參數
  • sparse_grad ( bool, optional) – 如果 True,則相對於 input 的梯度將會是一個稀疏張量。

  • out ( Tensor, optional) – 目標張量

範例

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適合初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並取得您問題的解答

檢視資源