
scatter operations missing index locations

Open • sarmientoF opened this issue 1 year ago • 2 comments

It would be great if the scatter operations also returned the index locations, like `torch_scatter` does:

```python
import torch
from torch_scatter import scatter_max

src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))

# Returns both the reduced values and the positions in `src` they came from.
out, argmax = scatter_max(src, index, out=out)
```

https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/max.html

Is this an upcoming feature?

sarmientoF, Nov 27 '24 13:11
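To make the request concrete, here is a minimal pure-Python sketch of what the example computes along the last dimension: a max-reduction of `src` into `dim_size` slots chosen by `index`, plus, for each slot, the position in `src` that supplied the maximum. The function names, the -1 marker for empty slots, and first-position tie-breaking are assumptions for illustration; `torch_scatter`'s own conventions (for example, how ties or a pre-filled `out` are handled) may differ between versions.

```python
def scatter_max_1d(src, index, dim_size, fill_value=0):
    """Max-reduce `src` into `dim_size` slots given by `index`; also return argmax.

    Assumed conventions: empty slots keep `fill_value` and get argmax -1;
    on ties, the first position encountered wins.
    """
    out = [fill_value] * dim_size
    argmax = [-1] * dim_size
    for pos, (value, slot) in enumerate(zip(src, index)):
        # First element landing in this slot, or a strictly larger value.
        if argmax[slot] == -1 or value > out[slot]:
            out[slot] = value
            argmax[slot] = pos
    return out, argmax


def scatter_max_rows(src, index, dim_size, fill_value=0):
    """Apply scatter_max_1d to each row, i.e. reduce along the last dimension."""
    results = [scatter_max_1d(s, i, dim_size, fill_value) for s, i in zip(src, index)]
    return [r[0] for r in results], [r[1] for r in results]


src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
out, argmax = scatter_max_rows(src, index, dim_size=6)
print(out)     # [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
print(argmax)  # [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
```

Both outputs come from the same loop, which is the "one pass gives both" property the request is after.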

There's no plan to add this at the moment. A `scatter_argmax` might be doable, but I'm curious: what would you use it for?

awni, Nov 27 '24 20:11

This is used in some graph architectures to speed up computation (one call instead of two), and it is not available in the mlx-graphs repo either: https://github.com/mlx-graphs/mlx-graphs/blob/4619d97ade4a788d1fe20b4d08b994ad56ee5ca0/mlx_graphs/utils/scatter.py#L10

thegodone, Nov 28 '24 05:11
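The "one call instead of two" point can be illustrated by composing an argmax from two scatter passes: a scatter-max, then a scatter-min over the positions whose value equals their slot's maximum. This minimal sketch uses NumPy's `ufunc.at` (`np.maximum.at`, `np.minimum.at`) on a 1-D input; `scatter_max_argmax` is a hypothetical name, and the empty-slot conventions (out = 0, argmax = -1) and first-position tie-breaking are assumptions. MLX arrays expose an `.at[...]` indexer with `maximum()` and `minimum()` updates, so the same two passes should translate, but this sketch has not been checked against MLX itself.

```python
import numpy as np


def scatter_max_argmax(src, index, dim_size):
    """Two-pass scatter max with argmax over a 1-D `src` (hypothetical helper).

    Pass 1 max-reduces `src` into `dim_size` slots. Pass 2 finds, per slot,
    the smallest position whose value equals that slot's maximum.
    Assumed conventions: empty slots get out = 0 and argmax = -1.
    """
    src = np.asarray(src, dtype=np.float32)
    index = np.asarray(index, dtype=np.int64)
    n = src.shape[0]

    # Pass 1: scatter-max, starting from -inf so only scattered values count.
    out = np.full(dim_size, -np.inf, dtype=np.float32)
    np.maximum.at(out, index, src)

    # Pass 2: gather each element's slot maximum, keep positions that reach it,
    # and scatter-min those positions. The sentinel n means "not a winner".
    positions = np.arange(n, dtype=np.int64)
    candidates = np.where(src == out[index], positions, n)
    argmax = np.full(dim_size, n, dtype=np.int64)
    np.minimum.at(argmax, index, candidates)

    # Slots that received no elements.
    empty = argmax == n
    out[empty] = 0.0
    argmax[empty] = -1
    return out, argmax


out, argmax = scatter_max_argmax([2, 0, 1, 4, 3], [4, 5, 4, 2, 3], dim_size=6)
print(out.tolist())     # [0.0, 0.0, 4.0, 3.0, 2.0, 0.0]
print(argmax.tolist())  # [-1, -1, 3, 4, 0, 1]
```

Besides the extra scatter, pass 2 also needs a gather (`out[index]`) and a comparison over every element; a fused `scatter_argmax` could track the winning position while reducing, which is presumably the saving the comment above refers to.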