Scatter vjp
Implement the vjp of `scatter` and `scatter_add`. This PR also adds the `scatter_add` op in Python for initial testing, since `x[indices] += y` cannot map to scatter add; it can only map to `x[indices] = x[indices] + y`.

The main question is whether we can automatically detect the above and replace it with `scatter_add` or `scatter_{op}` depending on the operation. It might also simply not be worth the trouble, and we could expose `scatter_{op}` and leave it to the user to utilize them for these updates.
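To make the distinction concrete, here is a NumPy sketch (NumPy stands in for mlx here; the indexing semantics in question are the same): with repeated indices, `x[indices] += y` is not a scatter add.

```python
import numpy as np

# With a duplicate index, `x[idx] += y` desugars to
# `x[idx] = x[idx] + y`: the gather happens once, so repeated
# indices all see the original value and the last write wins.
x = np.zeros(3)
idx = np.array([0, 0, 1])
y = np.ones(3)
x[idx] += y
print(x)   # [1. 1. 0.] -- index 0 was updated once, not twice

# A true scatter add accumulates every contribution.
x2 = np.zeros(3)
np.add.at(x2, idx, y)
print(x2)  # [2. 1. 0.]
```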
Remaining tasks
- [ ] Write tests in C++ and python
- [ ] Write docs if we choose to expose scatter_{op}
- [ ] Coordinate with #391 for the implementation of the above
Can we implement a rewriting pass through the graph within `simplify`?
We can, but it is not that simple, as they are not equivalent operations. The `scatter_add` is atomic while the other one isn't (still atomic, but not the increment). Simplify must not change the result of the graph.

As an aside, simplify must remain light so it can be run on the dynamic graph every time. As a result, I would be very cautious about adding checks there that may only benefit specific situations.

If we do decide that this should be handled with a graph rewrite, I think the proper place for that is the Python implementation of set item (`mlx_set_item`), but maybe it isn't worth it.
Thanks for the elaborate answer. To determine whether it is worth doing the rewrite, we can measure performance difference between the two operations.
Honestly, the more I think about it, the more I think this shouldn't be implemented. It would also break NumPy compatibility, which implements sliced `__iadd__` the same way.

The NumPy way to do scatter add is to use `np.add.at`, so maybe we should provide something like that, or just the scatter add ops which are common in ML frameworks.
Can we simply not expose `scatter` and implement `np.add.at` and similar? If we go for `np.add.at` equivalents, do we have existing kernels or other primitive operations for an efficient implementation?
Alternatively, we can implement operations such as those in https://pytorch-scatter.readthedocs.io/en/1.3.0/. Those are sufficient to support an efficient implementation of GNNs.
So this is ready for review. I started writing the documentation for `scatter_add` and quickly realized that there is a reason why people create the higher-level APIs. They are equally powerful but much simpler to reason about. I removed `scatter_add` and implemented all the scatter ops using the `array.at` pattern.
Simply put:

```python
x[idx] = y                 # maps to scatter
x = x.at[idx].add(y)       # maps to scatter_add
x = x.at[idx].subtract(y)  # maps to scatter_add
x = x.at[idx].multiply(y)  # maps to scatter_prod
...                        # etc.
```
Before merging it probably needs more tests and also some docs regarding the `at` pattern.
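As a sketch of how such an out-of-place update proxy can work (all names here are hypothetical NumPy stand-ins, not the mlx implementation): each update copies the array, so the original is untouched and repeated indices accumulate.

```python
import numpy as np

class At:
    """Toy out-of-place indexed-update proxy in the spirit of x.at[idx]."""
    def __init__(self, array):
        self._array = array

    def __getitem__(self, idx):
        return _IndexUpdate(self._array, idx)

class _IndexUpdate:
    def __init__(self, array, idx):
        self._array = array
        self._idx = idx

    def set(self, value):
        # Analogous to scatter: plain indexed assignment on a copy.
        out = self._array.copy()
        out[self._idx] = value
        return out

    def add(self, value):
        # Analogous to scatter_add: accumulates over duplicate indices.
        out = self._array.copy()
        np.add.at(out, self._idx, value)
        return out

x = np.zeros(3)
y = At(x)[np.array([0, 0, 1])].add(1.0)
print(y)  # [2. 1. 0.]
print(x)  # [0. 0. 0.] -- the original array is unchanged
```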
Oh also, @awni, I realize that it was already reviewed, but I would appreciate another look since the main paradigm changed in the meantime.
I'm wondering if we should make these in place or not. For example, `x[0].add(...)` could be an in-place op. It's not so much that I feel it should be one way or another, but right now there are subtle inconsistencies that I am concerned about. For example:
```python
>>> a = mx.array([0])
>>> b = a
>>> b[0] = 1
>>> a
array([1], dtype=int32)
>>> a = mx.array([0])
>>> b = a
>>> b += 1
>>> a
array([0], dtype=int32)
```
One option is to get rid of `b[0] = 1`, use `b.at[0].set(1)`, and do it out of place. But I also really love the clean syntax of `b[0] = 1` and would be sad to see it go.
To be fair, I think the inconsistency here is the `b += 1`. The rest are pretty expected. The solution here would be to implement `__iadd__` and friends so that people can do `x += 1` and get what they expect.
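For comparison, NumPy (whose behavior an mlx `__iadd__` would presumably match) keeps the alias in sync in both cases, because both `__setitem__` and `__iadd__` mutate in place:

```python
import numpy as np

a = np.array([0])
b = a
b[0] = 1   # in-place __setitem__: a and b still alias the same buffer
print(a)   # [1]

a = np.array([0])
b = a
b += 1     # in-place __iadd__ (unlike b = b + 1, which rebinds b)
print(a)   # [1]
```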
Regarding `get` and `set` on the `array.at` helper, I think they should definitely not be the only way to do those operations. Personally, I don't like it when there are many ways to do a thing; however, I can think of implementations such as

```python
def update(x, idx, value, op):
    return getattr(x.at[idx], op)(value)
```

where having `set` as an option would simplify the code.
> To be fair, I think the inconsistency here is the `b += 1`.
💯
Yes I realized my comment was a slight digression but this seemed like a good place to whine about it 😛
- Making ops on the proxy from `.at[]` out of place seems like a good approach and gives added flexibility over the in-place alternative, for `set` at least
- We should update the in-place ops to be consistent (not suggesting for this PR, but it should be done soon)
I think this looks clean and nice. My only comment is that, coming from other frameworks, people may be used to using scatter operations explicitly. So it would be great to add some extra docs mentioning the parallel (or referring to the underlying C++ implementation).