torch.Tensor.index_reduce_¶
- Tensor.index_reduce_(dim, index, source, reduce, *, include_self=True) Tensor ¶
透過使用
reduce
參數指定的縮減方式,將source
的元素累加到self
張量中,累加到index
中給定的索引。例如,如果dim == 0
、index[i] == j
、reduce == prod
且include_self == True
,則source
的第i
列會乘以self
的第j
列。如果include_self="True"
,則self
張量中的值會包含在縮減中;否則,會將累加到的self
張量中的列視為填充了縮減單位元素。source
的第dim
維度的大小必須與index
的長度相同(必須是一個向量),並且所有其他維度必須與self
相符,否則會引發錯誤。對於
reduce="prod"
且include_self=True
的 3-D 張量,輸出如下所示:self[index[i], :, :] *= src[i, :, :] # if dim == 0 self[:, index[i], :] *= src[:, i, :] # if dim == 1 self[:, :, index[i]] *= src[:, :, i] # if dim == 2
注意
當給定 CUDA 裝置上的張量時,此操作的行為可能不具決定性。有關更多資訊,請參閱再現性。
注意
此函式僅支援浮點張量。
警告
此函式目前處於 beta 階段,並且在不久的將來可能會發生變更。
- 參數
- 關鍵字引數
include_self (bool) – 是否將
self
張量中的元素包含在縮減中
範例
>>> x = torch.empty(5, 3).fill_(2) >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float) >>> index = torch.tensor([0, 4, 2, 0]) >>> x.index_reduce_(0, index, t, 'prod') tensor([[20., 44., 72.], [ 2., 2., 2.], [14., 16., 18.], [ 2., 2., 2.], [ 8., 10., 12.]]) >>> x = torch.empty(5, 3).fill_(2) >>> x.index_reduce_(0, index, t, 'prod', include_self=False) tensor([[10., 22., 36.], [ 2., 2., 2.], [ 7., 8., 9.], [ 2., 2., 2.], [ 4., 5., 6.]])