scatter operations missing index locations
It would be great if the scatter operations could also return the index locations, as torch_scatter does:

```python
import torch
from torch_scatter import scatter_max

src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))

# Returns the scattered maxima and, for each one, the position in src it came from
out, argmax = scatter_max(src, index, out=out)
```

https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/max.html
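To pin down the semantics being requested, here is a plain-Python reference for the 1-D case. It is only a sketch, not torch_scatter's implementation: the name `scatter_max_reference`, the `-1` sentinel for slots that receive nothing, and first-occurrence tie-breaking are assumptions, and unlike the `out=out` call in the snippet it does not compare against a pre-filled output.

```python
def scatter_max_reference(src, index, dim_size, fill_value=0.0):
    """Hypothetical plain-Python reference for a 1-D scatter_max with argmax.

    out[i] is the maximum of all src[j] whose index[j] == i, and argmax[i]
    is the position j that produced it (first occurrence wins on ties).
    Slots that receive nothing keep fill_value and an argmax of -1
    (a sentinel chosen for this sketch).
    """
    out = [fill_value] * dim_size
    argmax = [-1] * dim_size
    for j, (value, i) in enumerate(zip(src, index)):
        # One pass updates the maximum and its source position together.
        if argmax[i] == -1 or value > out[i]:
            out[i] = value
            argmax[i] = j
    return out, argmax
```

For the first row of the example, `scatter_max_reference([2, 0, 1, 4, 3], [4, 5, 4, 2, 3], 6)` returns `([0.0, 0.0, 4, 3, 2, 0], [-1, -1, 3, 4, 0, 1])`.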
Is this an upcoming feature?
There's no plan to add this at the moment. A `scatter_argmax` might be doable, but I'm curious: what would you use it for?
This is used in some graph architectures to speed up computation (one call instead of two), and it is not available in the mlx-graphs repo: https://github.com/mlx-graphs/mlx-graphs/blob/4619d97ade4a788d1fe20b4d08b994ad56ee5ca0/mlx_graphs/utils/scatter.py#L10
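The "one call instead of two" point can be made concrete. Without a fused kernel, the argmax has to be recovered with a second scatter after the max. The sketch below shows that two-pass workaround in NumPy, which stands in for MLX here; `scatter_max_two_pass` and its `-1` sentinel for empty slots are hypothetical and not part of MLX or mlx-graphs.

```python
import numpy as np


def scatter_max_two_pass(src, index, dim_size, fill_value=0.0):
    """Hypothetical 1-D scatter-max that recovers argmax with a second scatter.

    Pass 1 scatters the maxima; pass 2 scatters the smallest source position
    whose value equals its slot's maximum. Empty slots get fill_value in out
    and -1 in argmax (a sentinel chosen for this sketch).
    """
    src = np.asarray(src, dtype=np.float64)
    index = np.asarray(index, dtype=np.int64)
    n = src.shape[0]

    # Pass 1: unbuffered scatter-max, so repeated indices are all applied.
    out = np.full(dim_size, -np.inf)
    np.maximum.at(out, index, src)

    # Pass 2: keep positions that attain their slot's max, then scatter-min
    # them so ties resolve to the first occurrence.
    hits = src == out[index]
    positions = np.arange(n)
    argmax = np.full(dim_size, n, dtype=np.int64)  # n means "no source yet"
    np.minimum.at(argmax, index[hits], positions[hits])

    # Slots that never received a value get the fill value and the sentinel.
    empty = argmax == n
    out[empty] = fill_value
    argmax[empty] = -1
    return out, argmax
```

On the first row of the torch_scatter example, `scatter_max_two_pass([2, 0, 1, 4, 3], [4, 5, 4, 2, 3], 6)` gives maxima `[0., 0., 4., 3., 2., 0.]` and argmax `[-1, -1, 3, 4, 0, 1]`. The `np.minimum.at` pass is the extra call that a fused scatter-max-with-argmax would avoid, since a fused kernel can track each value and its position together in a single sweep.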