
Scatter vjp

Open angeloskath opened this issue 1 year ago • 12 comments

Implement the vjp of scatter and scatter_add. This PR also adds the scatter_add op in Python for initial testing, since

```python
x[indices] += y
```

cannot map to scatter_add; it can only map to

```python
x[indices] = x[indices] + y
```

The main question is whether we can automatically detect the above and replace it with scatter_add (or scatter_{op}, depending on the operation). It might also simply not be worth the trouble, in which case we could expose scatter_{op} and leave it to the user to use them for these updates.
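The reason `x[indices] += y` cannot reach scatter_add directly is Python's desugaring: the statement expands into a `__getitem__`, an add on the retrieved value, and then a plain `__setitem__`, so the container only ever observes an assignment. A small probe class (purely illustrative, not MLX code) makes this visible:

```python
class Probe:
    """Log which container dunder methods Python invokes for `x[i] += y`."""

    def __init__(self):
        self.calls = []

    def __getitem__(self, idx):
        self.calls.append("getitem")
        return 0  # the addition happens on this retrieved value

    def __setitem__(self, idx, value):
        self.calls.append("setitem")


x = Probe()
x[0] += 1
print(x.calls)  # -> ['getitem', 'setitem']
```

The container never sees an "add at index" operation, only a read followed by a write, which is why the in-place syntax alone cannot select a fused scatter_add.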

Remaining tasks

  • [ ] Write tests in C++ and python
  • [ ] Write docs if we choose to expose scatter_{op}
  • [ ] Coordinate with #391 for the implementation of the above

angeloskath · Jan 07 '24 10:01

> The main question is whether we can automatically detect the above and replace it with scatter_add or scatter_{op} depending on the operation. It might also simply not be worth the trouble and we could expose scatter_{op} and leave it to the user to utilize them for these updates.

Can we implement a rewriting pass over the graph within simplify?

gboduljak · Jan 07 '24 21:01

We can, but it is not that simple, as they are not equivalent operations. scatter_add is atomic, while the other one isn't (the scatter itself is still atomic, but the increment is not). Simplify must not change the result of the graph.

As an aside, simplify must remain lightweight so it can be run on the dynamic graph every time. As a result, I would be very cautious about adding checks there that may only benefit specific situations.

If we do decide that this should be handled with a graph rewrite, I think the proper place for it is the Python implementation of set item (mlx_set_item), but maybe it isn't worth it.
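To make the non-equivalence concrete, numpy exhibits the same distinction: with repeated indices, a plain fancy-index assignment reads once and so drops contributions (the last write wins), while an unbuffered scatter-add accumulates them. A small numpy sketch, not MLX code:

```python
import numpy as np

idx = np.array([0, 0, 1])  # index 0 appears twice
y = np.ones(3)

# Plain fancy-index assignment: the read happens once, so repeated
# indices do not accumulate.
a = np.zeros(3)
a[idx] = a[idx] + y
print(a)  # -> [1. 1. 0.]

# np.add.at performs an unbuffered (accumulating) scatter-add.
b = np.zeros(3)
np.add.at(b, idx, y)
print(b)  # -> [2. 1. 0.]
```

Because the two can produce different results, a simplify pass that silently swapped one for the other would change the graph's output for duplicate indices.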

angeloskath · Jan 07 '24 23:01

> We can but it is not that simple as they are not equivalent operations. The scatter_add is atomic while the other one isn't (still atomic but not the increment). Simplify must not change the result of the graph.
>
> As an aside, simplify must remain light so it can be run on the dynamic graph every time. As a result I would be very cautious adding checks there that only may benefit specific situations.
>
> If we do decide that this should be handled with a graph rewrite I think the proper place for that is the python implementation of set item (mlx_set_item) but maybe it isn't worth it.

Thanks for the elaborate answer. To determine whether the rewrite is worth doing, we could measure the performance difference between the two operations.

gboduljak · Jan 07 '24 23:01

Honestly, the more I think about it, the more I think this shouldn't be implemented. It would also break numpy compatibility, which implements sliced __iadd__ the same way.

The numpy way to do a scatter add is to use np.add.at, so maybe we should provide something like that, or just the scatter add ops which are common in ML frameworks.

angeloskath · Jan 07 '24 23:01

> Honestly the more I think about it I think this shouldn't be implemented. It would also break numpy compatibility which implements sliced __iadd__ the same way.
>
> The numpy way to do scatter add is to use np.add.at so maybe we should provide something like that or just the scatter add ops which are common in ML frameworks.

Can we simply not expose scatter and implement np.add.at and similar? If we go for np.add.at equivalents, do we have existing kernels or other primitive operations for an efficient implementation?

gboduljak · Jan 07 '24 23:01

> Honestly the more I think about it I think this shouldn't be implemented. It would also break numpy compatibility which implements sliced __iadd__ the same way. The numpy way to do scatter add is to use np.add.at so maybe we should provide something like that or just the scatter add ops which are common in ML frameworks.
>
> Can we simply not expose scatter and implement np.add.at and similar? If we go for np.add.at equivalents, do we have existing kernels or other primitive operations for an efficient implementation?

Alternatively, we could implement operations such as those in https://pytorch-scatter.readthedocs.io/en/1.3.0/. Those are sufficient to support efficient implementations of GNNs.
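For reference, ops in that style can be sketched on top of numpy's unbuffered ufunc.at machinery. The scatter_max below is a hypothetical helper written here for illustration, in the spirit of pytorch_scatter's scatter_max; it is not an MLX API:

```python
import numpy as np

def scatter_max(src, index, size):
    """Reduce src into `size` buckets by taking the max per index.
    Illustrative sketch using numpy's unbuffered ufunc.at; not MLX code."""
    out = np.full(size, -np.inf)
    np.maximum.at(out, index, src)  # accumulating max, safe for duplicates
    return out

src = np.array([1.0, 5.0, 2.0, 4.0])
index = np.array([0, 0, 1, 1])
print(scatter_max(src, index, 2))  # -> [5. 4.]
```

The same pattern with np.add.at or np.minimum.at covers the usual scatter reductions used in GNN message passing.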

gboduljak · Jan 07 '24 23:01

So this is ready for review. I started writing the documentation for scatter_add and quickly realized that there is a reason why people create the higher-level APIs: they are equally powerful but much simpler to reason about. I removed scatter_add and implemented all the scatter ops using the array.at pattern.

Simply put:

```python
x[idx] = y                 # maps to scatter
x = x.at[idx].add(y)       # maps to scatter_add
x = x.at[idx].subtract(y)  # maps to scatter_add
x = x.at[idx].multiply(y)  # maps to scatter_prod
# ... etc ...
```

Before merging it probably needs more tests and also some docs regarding the at pattern.
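One way to picture the array.at pattern is as an indexing proxy whose methods return a new array instead of mutating the original. The sketch below is a hypothetical numpy-based illustration (class and method names invented here), not the MLX implementation:

```python
import numpy as np

class _AtProxy:
    """Out-of-place update proxy: each method returns a fresh array."""

    def __init__(self, arr, idx):
        self._arr, self._idx = arr, idx

    def add(self, y):
        out = self._arr.copy()        # out of place: the original is untouched
        np.add.at(out, self._idx, y)  # stands in for a scatter_add primitive
        return out

class AtArray(np.ndarray):
    @property
    def at(self):
        outer = self

        class _Indexer:
            def __getitem__(self, idx):
                return _AtProxy(outer, idx)

        return _Indexer()

x = np.zeros(3).view(AtArray)
y = x.at[np.array([0, 0])].add(1.0)
print(y)  # repeated index 0 accumulates
print(x)  # original array is unchanged
```

Because the proxy's methods are out of place, repeated indices accumulate correctly and the source array keeps its value, which matches the functional semantics discussed above.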

angeloskath · Jan 08 '24 08:01

Oh also, @awni, I realize that it was already reviewed, but I would appreciate another look since the main paradigm changed in the meantime.

angeloskath · Jan 08 '24 08:01

I'm wondering if we should make these in place or not. For example, x[0].add(...) could be an in-place op. It's not so much that I feel it should be one way or the other, but right now there are subtle inconsistencies that concern me. For example:

```python
>>> a = mx.array([0])
>>> b = a
>>> b[0] = 1
>>> a
array([1], dtype=int32)
>>> a = mx.array([0])
>>> b = a
>>> b += 1
>>> a
array([0], dtype=int32)
```

One option is to get rid of b[0] = 1, use b.at[0].set(1) instead, and do it out of place. But I also really love the clean syntax of b[0] = 1 and would be sad to see it go.
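For comparison, numpy keeps both operations in place, so the two aliasing cases above agree there. A small numpy demonstration (not MLX code):

```python
import numpy as np

a1 = np.array([0])
b1 = a1
b1[0] = 1   # in-place write through the alias
print(a1)   # -> [1]

a2 = np.array([0])
b2 = a2
b2 += 1     # numpy's __iadd__ is also in place
print(a2)   # -> [1]  (both views see the update)
```

In numpy both statements mutate the shared buffer, which is exactly the consistency that is missing in the MLX session above, where b += 1 rebinds b to a new array.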

awni · Jan 08 '24 19:01

To be fair, I think the inconsistency here is the b += 1; the rest are pretty expected. The solution would be to implement __iadd__ and friends so that people can do x += 1 and get what they expect.

Regarding get and set on the array.at helper, I think they should definitely not be the only way to do those operations. Personally I don't like it when there are many ways to do a thing; however, I can think of implementations such as

```python
def update(x, idx, value, op):
    return getattr(x.at[idx], op)(value)
```

where having set as an option would simplify the code.

angeloskath · Jan 08 '24 19:01

> To be fair I think the inconsistency here is the b += 1

💯

Yes I realized my comment was a slight digression but this seemed like a good place to whine about it 😛

  • Making the ops on the proxy from .at[] out of place seems like a good approach and gives added flexibility over the in-place alternative, for set at least
  • We should update the in-place ops to be consistent (not suggesting it for this PR, but it should be done soon)

awni · Jan 08 '24 19:01

> So this is ready for review, I started writing the documentation for scatter_add and quickly realized that there is a reason why people create the higher level APIs. They are equally powerful but much simpler to reason about. I removed scatter_add and implemented all the scatter ops using the array.at pattern.
>
> Simply put
>
> ```python
> x[idx] = y                 # maps to scatter
> x = x.at[idx].add(y)       # maps to scatter_add
> x = x.at[idx].subtract(y)  # maps to scatter_add
> x = x.at[idx].multiply(y)  # maps to scatter_prod
> # ... etc ...
> ```
>
> Before merging it probably needs more tests and also some docs regarding the at pattern.

I think this looks clean and nice. My only comment is that, coming from other frameworks, people may be used to using scatter operations explicitly. So it would be great to add some extra docs mentioning the parallel (or referring to the underlying C++ implementation).

francescofarina · Jan 08 '24 19:01