Python bindings for scatter operations

Open francescofarina opened this issue 1 year ago • 1 comments

Proposed changes

Expose scatter operations to the python API.

Any help on completing this is welcome.

It's up for discussion whether we should expose only a scatter operation with different modes, or also the individual scatter_add, scatter_prod, etc.

Checklist

Put an x in the boxes that apply.

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] I have updated the necessary documentation (if needed)

francescofarina avatar Jan 06 '24 20:01 francescofarina

Ideally, we need examples in docs and tests for the binding. I will add some tests.

gboduljak avatar Jan 06 '24 22:01 gboduljak

Hi, thanks for starting this, that's great! However, if we are to expose scatter, it would mirror the C++ API via scatter_add, scatter_prod, etc., not a general scatter that accepts a mode parameter.

Additionally, indexing is generally a friendlier interface, so it might be worth mapping those to x[indices] += y rather than scatter_add, for instance. However, as I mentioned in #394, it would require several shenanigans, so it might not be worth it.
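To illustrate why mapping x[indices] += y onto scatter_add is not automatic, here is a minimal sketch that uses NumPy as a stand-in (an assumption; this is not MLX code). In NumPy, an in-place add through an index array does not accumulate repeated indices, while `np.add.at` has scatter-add semantics.

```python
import numpy as np

x = np.zeros(4)
indices = np.array([0, 0, 2])  # index 0 appears twice
y = np.array([1.0, 2.0, 3.0])

# Indexed in-place add: equivalent to a[indices] = a[indices] + y,
# so the two writes to index 0 overwrite each other instead of summing.
a = x.copy()
a[indices] += y

# Scatter-add semantics: every update is accumulated, duplicates included.
b = x.copy()
np.add.at(b, indices, y)
```

With duplicates present, `b[0]` holds the sum of both updates while `a[0]` holds only one of them, which is the gap an indexing-based interface would have to bridge.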

The most probable future path, imho, is to expose scatter_{op}, so I think the work you are doing is not in vain, but I am writing so that we can coordinate.

angeloskath avatar Jan 07 '24 10:01 angeloskath

Hi @angeloskath, great - we're iterating on exposing each scatter_{op}. Do you think also having the generic scatter with a mode (or reduce) parameter (as @gboduljak implemented in https://github.com/francescofarina/mlx/pull/1) would be handy? That's often available in other libraries.
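For reference, the two API shapes under discussion can be sketched side by side. This is a rough illustration in NumPy under stated assumptions: the `scatter` function, its `mode` values, and the `_REDUCERS` table are hypothetical names, not the signatures from either PR.

```python
import numpy as np

# Hypothetical mapping from a mode string to the NumPy ufunc that reduces it.
_REDUCERS = {
    "add": np.add,
    "prod": np.multiply,
    "max": np.maximum,
    "min": np.minimum,
}


def scatter(x, indices, updates, mode="assign"):
    """Generic form: return a copy of x with updates applied at indices."""
    out = np.array(x, copy=True)
    if mode == "assign":
        out[indices] = updates
    elif mode in _REDUCERS:
        # ufunc.at applies the reduction unbuffered, so repeated indices accumulate.
        _REDUCERS[mode].at(out, indices, updates)
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return out


# Per-op form: thin wrappers, one function per reduction.
def scatter_add(x, indices, updates):
    return scatter(x, indices, updates, mode="add")


def scatter_prod(x, indices, updates):
    return scatter(x, indices, updates, mode="prod")
```

Either shape can be built on top of the other, so the choice is mostly about which surface is easier to document and discover.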

francescofarina avatar Jan 07 '24 13:01 francescofarina

> The most probable future path, imho, is to expose scatter_{op}, so I think the work you are doing is not in vain, but I am writing so that we can coordinate.

As per @angeloskath's suggestion, I will work on exposing scatter_{op} instead of generic scatter.

gboduljak avatar Jan 07 '24 20:01 gboduljak

> As per @angeloskath's suggestion, I will work on exposing scatter_{op} instead of generic scatter.

This is now done. @francescofarina please see https://github.com/francescofarina/mlx/pull/2/.

gboduljak avatar Jan 07 '24 21:01 gboduljak

I'm slightly confused, there are two ongoing PRs for scatter ops in MLX 🤔 (#394 has bindings as well). It seems like they are going for different APIs, but are otherwise mostly the same?

awni avatar Jan 08 '24 18:01 awni

Looks like this one can be safely closed and we can move forward with https://github.com/ml-explore/mlx/pull/394, which has a different API. @angeloskath @gboduljak ?

francescofarina avatar Jan 08 '24 18:01 francescofarina

Yeah, sorry for stepping on your toes, guys.

It kinda evolved from writing the gradients for scatter and scatter_add, and then realizing that scatter_{op} would be kinda hard for people to understand (or rather for me to document properly), so I implemented the array.at interface, which is much simpler to understand.
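A rough sketch of what an at-style interface can look like, written against NumPy so it runs anywhere. The `At` and `_AtIndex` classes and their method names are hypothetical illustrations of the idea, not the implementation in #394.

```python
import numpy as np


class _AtIndex:
    """Holds an array and an index; each method returns an updated copy."""

    def __init__(self, data, idx):
        self._data = data
        self._idx = idx

    def _apply(self, ufunc, values):
        out = self._data.copy()
        # Unbuffered application, so repeated indices accumulate.
        ufunc.at(out, self._idx, values)
        return out

    def add(self, values):
        return self._apply(np.add, values)

    def multiply(self, values):
        return self._apply(np.multiply, values)

    def maximum(self, values):
        return self._apply(np.maximum, values)


class At:
    """Hypothetical wrapper giving At(x)[idx].add(y) style updates."""

    def __init__(self, data):
        self._data = np.asarray(data)

    def __getitem__(self, idx):
        return _AtIndex(self._data, idx)


x = np.zeros(4)
idx = np.array([0, 0, 2])
y = At(x)[idx].add(np.array([1.0, 2.0, 3.0]))
```

The call reads like indexing followed by the operation, which is arguably easier to explain than a family of scatter_{op} functions with separate index and update arguments.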

angeloskath avatar Jan 08 '24 18:01 angeloskath

Closing this in favor of #394. @francescofarina @gboduljak your feedback on #394 is appreciated.

awni avatar Jan 08 '24 19:01 awni