Array API: remove array mutation from functions/write JAX version of each array mutating function
Conversation from Slack -
2 quick things from my initial experiments -
We will need custom JAX implementations for most of the functions, as JAX arrays are immutable. See how these two projects handle it:
- s2fft (each file has a corresponding *_jax file; Matt Price worked on this), and
- awkward (all other backends share common reducers, but the JAX backend has specialized reducers; I worked on this).
They were developed independently, but both follow the same pattern: a standard implementation of a particular function, and then a JAX implementation of the same function.
Given that the project's first half is more focussed on GPUs and less on AD, I experimented with CuPy, but their Array API support is still experimental (and I might have also stumbled on a bug - https://github.com/cupy/cupy/issues/8747).
I thought the only place in the code where we cannot trivially get around the mutation was in iternorm(), but I haven't actually gone through and checked.
Oh yes, or remove all the lines of code that are mutating arrays. Maybe I should check for all such lines using ruff.
So the mutation in iternorm would be quite organically fixed by the "deep dive" project outside the GPU work that I was mentioning. We definitely have a path forward there.
My inclination would be that we change all functions to not mutate arrays and then we can use jax and everything else in one go. Is there any disadvantage of this approach?
Only in the way iternorm currently works, since it always needs to discard a large array and reallocate a fresh large array that's exactly the same. But probably negligible in the grand scheme of things, within GLASS at least (the code predates use in GLASS).
For array mutations, we can either rewrite functions to build copies of the original arrays, or we can write a separate implementation for JAX (using the at syntax).
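To make the first option concrete, here is a minimal sketch of turning a mutating function into one that builds a new array. The function names (clip_negative_inplace, clip_negative_pure) and the xp namespace argument are hypothetical and only for illustration; they are not taken from the GLASS code base.

```python
import numpy as np


def clip_negative_inplace(x):
    """Mutating version: overwrites entries of the caller's array."""
    x[x < 0] = 0.0
    return x


def clip_negative_pure(x, xp=np):
    """Non-mutating version: builds a new array and leaves x untouched.

    xp is assumed to be an array namespace that provides where() and
    zeros_like(), for example numpy or jax.numpy.
    """
    return xp.where(x < 0, xp.zeros_like(x), x)


a = np.array([-1.0, 2.0, -3.0])
b = clip_negative_pure(a)  # a keeps its negative entries, b does not
```

The non-mutating version relies only on where(), so the same body should also work with immutable arrays, at the cost of allocating a fresh array on every call.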
The at syntax produces copies too, but it applies changes in-place whenever the function is JIT compiled, which is a really nice performance advantage.
None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
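A small sketch of the behaviour described above, assuming jax is installed. The function name scaled_with_first_zeroed is made up for illustration. Note that the in-place optimisation is something the compiler does for values it owns inside the jitted function; the array passed in by the caller is still left untouched.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Outside jit: .at[...].set(...) returns a modified copy, x is unchanged.
y = x.at[0].set(5.0)


@jax.jit
def scaled_with_first_zeroed(x):
    # y is an intermediate value owned by the compiled function, so the
    # compiler is free to apply the update to its buffer in place.
    y = 2.0 * x
    y = y.at[0].set(0.0)
    return y


z = scaled_with_first_zeroed(x)
```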
I think having a separate JAX implementation makes the most sense here to leverage better memory management whenever functions are JIT compiled.
Ideally should be generic, but no worries if things have specialized implementations (we are not worried a lot about the JIT support) -- just add a dispatching mechanism to hide the complexity from users
I have started having a go at this in #699. It would be good to get your input, @ntessore, as there are different ways this could be done: creating new functions for JAX vs. updating the existing ones, and writing completely new algorithms which don't require mutation vs. naively removing mutations on a line-by-line basis.