mlx
Python bindings for scatter operations
Proposed changes
Expose scatter operations to the Python API.
Any help on completing this is welcome.
It's up for discussion whether we should expose only a scatter operation with different modes, or also the individual scatter_add, scatter_prod, ... operations.
Checklist
Put an x in the boxes that apply.
- [x] I have read the CONTRIBUTING document
- [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)
Ideally, we need examples in docs and tests for the binding. I will add some tests.
Hi, thanks for starting this, that's great! However, if we are to expose scatter, it would be similar to C++ via scatter_add, scatter_prod, etc., not via a general scatter that accepts a mode parameter.
Additionally, indexing is generally a friendlier interface, so it might be worth mapping those to x[indices] += y rather than scatter_add, for instance. That said, as I mentioned in #394, it would require several shenanigans, so it might not be worth it.
The most probable future path, imho, is to expose scatter_{op}, so I think the work you are doing is not in vain; I am writing so that we can coordinate.
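One concrete subtlety behind mapping x[indices] += y onto scatter_add is our own reading, since the thread does not spell out the "shenanigans". Python desugars an augmented subscript assignment into a gather (x[indices]), an add, and a separate assignment back (x[indices] = ...). The array library therefore never sees an "add at these indices" request, and repeated indices do not accumulate. NumPy documents exactly this difference between a[idx] += b and np.add.at. The pure-Python model below illustrates it. The function names are hypothetical, and the "later write wins" behavior for repeated indices is an assumption of the model, not a guarantee of any particular library.

```python
from typing import List, Sequence


def indexed_iadd(x: Sequence[float], indices: Sequence[int],
                 updates: Sequence[float]) -> List[float]:
    """Model of x[indices] += updates as Python desugars it:
    tmp = x[indices] + updates, then x[indices] = tmp."""
    gathered = [x[i] for i in indices]                    # x[indices]
    summed = [g + u for g, u in zip(gathered, updates)]   # ... + updates
    out = list(x)
    for i, s in zip(indices, summed):                     # x[indices] = ...
        out[i] = s  # assumption: for a repeated index, the later write wins
    return out


def scatter_add(x: Sequence[float], indices: Sequence[int],
                updates: Sequence[float]) -> List[float]:
    """Accumulating scatter: every update lands, even on repeated indices."""
    out = list(x)
    for i, u in zip(indices, updates):
        out[i] += u
    return out


x = [0.0, 0.0, 0.0]
idx = [0, 0, 2]
y = [1.0, 2.0, 3.0]
print(indexed_iadd(x, idx, y))  # [2.0, 0.0, 3.0]: the first update to x[0] is lost
print(scatter_add(x, idx, y))   # [3.0, 0.0, 3.0]: both updates to x[0] are summed
```

With unique indices the two agree, so the mismatch only shows up with duplicates. That is also the case that matters for scatter-style reductions and their gradients.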
Hi @angeloskath, great, we're iterating on exposing each scatter_{op}. Do you think it would also be handy to have the generic scatter with the mode (or reduce) parameter, as @gboduljak implemented in https://github.com/francescofarina/mlx/pull/1? That's often available in other libraries.
As per @angeloskath's suggestion, I will work on exposing scatter_{op} instead of the generic scatter.
This is now done. @francescofarina please see https://github.com/francescofarina/mlx/pull/2/.
I'm slightly confused: there are two ongoing PRs for scatter ops in MLX 🤔 (#394 has bindings as well). It seems like they are going for different APIs but are otherwise mostly the same?
Looks like this one can be safely closed and we can move forward with https://github.com/ml-explore/mlx/pull/394, which has a different API. @angeloskath @gboduljak ?
Yeah, sorry for stepping on your toes guys.
It kinda evolved from writing the gradients for scatter and scatter_add, and then realizing that scatter_{op} would be kinda hard for people to understand (or rather, for me to document properly), so I implemented the array.at interface, which is much simpler to understand.
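The array.at interface turns the reduction name into a method on an indexing helper instead of baking it into a function name. As far as we know, released MLX versions expose it as a.at[idx].add(y), alongside subtract, multiply, divide, maximum and minimum. The toy class below only imitates that shape in pure Python and is not MLX's implementation. The Array, _AtIndexer and _AtUpdater names are hypothetical, and so is the choice to support only 1-D lists with scalar broadcasting.

```python
from typing import Callable, List, Sequence, Union

Indices = Union[int, Sequence[int]]
Updates = Union[float, Sequence[float]]


class Array:
    """Toy 1-D array exposing an .at[indices].op(updates) interface."""

    def __init__(self, data: Sequence[float]):
        self.data: List[float] = list(data)

    @property
    def at(self) -> "_AtIndexer":
        # a.at captures the array; indexing it then captures the indices.
        return _AtIndexer(self)

    def __repr__(self) -> str:
        return f"Array({self.data})"


class _AtIndexer:
    def __init__(self, arr: Array):
        self._arr = arr

    def __getitem__(self, indices: Indices) -> "_AtUpdater":
        return _AtUpdater(self._arr, indices)


class _AtUpdater:
    def __init__(self, arr: Array, indices: Indices):
        self._arr = arr
        self._indices = [indices] if isinstance(indices, int) else list(indices)

    def _apply(self, updates: Updates,
               combine: Callable[[float, float], float]) -> Array:
        if isinstance(updates, (int, float)):
            # Broadcast a scalar update to every index.
            updates = [updates] * len(self._indices)
        if len(updates) != len(self._indices):
            raise ValueError("indices and updates must have the same length")
        out = list(self._arr.data)
        for i, u in zip(self._indices, updates):
            out[i] = combine(out[i], u)  # repeated indices accumulate
        # Return a new array and leave the original untouched.
        return Array(out)

    def add(self, updates: Updates) -> Array:
        return self._apply(updates, lambda a, b: a + b)

    def multiply(self, updates: Updates) -> Array:
        return self._apply(updates, lambda a, b: a * b)

    def maximum(self, updates: Updates) -> Array:
        return self._apply(updates, max)

    def minimum(self, updates: Updates) -> Array:
        return self._apply(updates, min)


a = Array([0.0, 0.0])
b = a.at[[0, 1, 0, 1]].add(1.0)
print(b)  # Array([2.0, 2.0])
print(a)  # Array([0.0, 0.0])
```

Reading a.at[idx].add(y) next to a[idx] += y makes the difference in accumulation explicit at the call site. Returning a new array, so that callers write a = a.at[idx].add(y), is a design choice of this sketch that keeps updates functional.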
Closing this in favor of #394. @francescofarina @gboduljak your feedback on #394 is appreciated.