torch.scatter_reduce¶

torch.scatter_reduce(input, dim, index, reduce, *, output_size=None) → Tensor¶

Reduces all values from the input tensor to the indices specified in the index tensor. For each value in input, its output index is specified by its index in input for dimension != dim and by the corresponding value in index for dimension = dim. The applied reduction for non-unique indices is defined via the reduce argument ("sum", "prod", "mean", "amax", "amin"). For non-existing indices, the output will be filled with the identity of the applied reduction (1 for "prod" and 0 otherwise).
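For instance, the identity fill is visible when an output index is never referenced; a minimal sketch under the semantics above (expected output derived by hand, not from a specific build):

>>> input = torch.tensor([2., 3., 4.])
>>> index = torch.tensor([0, 0, 2])
>>> # output index 1 is never referenced, so it is filled with the
>>> # multiplicative identity 1 for reduce="prod"
>>> torch.scatter_reduce(input, 0, index, reduce="prod", output_size=3)
tensor([6., 1., 4.])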
It is also required that index.size(d) == input.size(d) for all dimensions d. Moreover, if output_size is defined, the values of index must be between 0 and output_size - 1 inclusive.

For a 3-D tensor with reduce="sum", the output is given as:

out[index[i][j][k]][j][k] += input[i][j][k]  # if dim == 0
out[i][index[i][j][k]][k] += input[i][j][k]  # if dim == 1
out[i][j][index[i][j][k]] += input[i][j][k]  # if dim == 2
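Applying the dim == 1 rule above to a 2-D tensor, a minimal sketch (expected output derived by hand from these rules):

>>> input = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> index = torch.tensor([[0, 0, 1], [1, 0, 0]])
>>> # row 0: out[0][0] = 1 + 2, out[0][1] = 3
>>> # row 1: out[1][0] = 5 + 6, out[1][1] = 4
>>> torch.scatter_reduce(input, 1, index, reduce="sum")
tensor([[ 3,  3],
        [11,  4]])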
Note

This out-of-place operation is similar to the in-place versions of scatter_() and scatter_add_(), in which the output tensor is automatically created according to the maximum values in index and filled based on the identity of the applied reduction. A comparison with scatter_add_() is sketched below.

Note

This operation may behave nondeterministically when given tensors on a CUDA device. See Reproducibility for more information.
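For comparison, a minimal sketch of the same "sum" reduction expressed with the in-place scatter_add_(), assuming the output is pre-created by hand and zero-initialized (the identity for "sum"):

>>> input = torch.tensor([1, 2, 3, 4, 5, 6])
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
>>> out = torch.zeros(3, dtype=torch.int64)  # output created manually
>>> out.scatter_add_(0, index, input)
tensor([ 4, 12,  5])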
- Parameters

input (Tensor) – the input tensor; its values are the source elements to scatter and reduce

dim (int) – the axis along which to index

index (LongTensor) – the indices of elements to scatter and reduce

reduce (str) – the reduction operation to apply for non-unique indices ("sum", "prod", "mean", "amax", "amin")

output_size (int, optional) – the size of the output at dimension dim. If set to None, it is inferred automatically as index.max() + 1
Example:

>>> input = torch.tensor([1, 2, 3, 4, 5, 6])
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
>>> torch.scatter_reduce(input, 0, index, reduce="sum", output_size=3)
tensor([4, 12, 5])
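Leaving output_size as None infers it from index.max() + 1; a sketch of the same data with reduce="amax" (expected output derived by hand from the rules above):

>>> torch.scatter_reduce(input, 0, index, reduce="amax")
tensor([3, 6, 5])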