array-api

RFC: static vs. dynamic shapes and JAX's `.at` for simulating in-place ops

Status: Open • rgommers opened this issue 1 year ago • 7 comments

This is a continuation of a discussion that started a few weeks ago in gh-597 (Cc @soraros). It is closely related to gh-84 (boolean indexing) and gh-24 (mutability and copy/views).

I'll copy the content of @soraros's comment here in full:

Start of comment

I also think the problem is more fundamental than that. JAX is essentially a front-end for XLA, and the primitives provided by XLA (for now) require static shapes. So the line that actually goes wrong is

>>> xs[ix_bool]
array([0, 2, 4])

Note that this code does work in JAX, though it is not jittable, because its output shape depends on the values in ix_bool. Let's pretend for a moment that x[ix_bool] += 1 is syntactic sugar for x = x + where(ix_bool, 1, 0) (which works in JAX). The same problem appears when we want x[ix_bool] += [1, 3, 5]. Again, we somehow need to know the shape of the rhs, which is equivalent to knowing the shape of xs[ix_bool] in the previous example.
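A small NumPy sketch makes the shape argument concrete. The values of `xs` and `ix_bool` are assumptions chosen to reproduce the output shown above. The `where`-based rewrite always has the shape of `xs`, while scattering `[1, 3, 5]` into the masked positions requires the number of `True` entries, a value-dependent quantity that a static-shape compiler cannot know when it traces the code.

```python
import numpy as np

xs = np.arange(6)
ix_bool = xs % 2 == 0  # assumed mask: [T, F, T, F, T, F]

# Boolean indexing: the output length depends on the *values* in ix_bool.
selected = xs[ix_bool]

# where-based rewrite of `x[ix_bool] += 1`: the output shape equals xs.shape
# no matter how many entries of the mask are True.
incremented = xs + np.where(ix_bool, 1, 0)

# `x[ix_bool] += [1, 3, 5]` needs an rhs whose length equals the number of
# True entries, which is only known once the mask values are known.
n_true = int(np.count_nonzero(ix_bool))
rhs = np.array([1, 3, 5])
assert rhs.shape == (n_true,)
scattered = xs.copy()
scattered[ix_bool] += rhs
```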

So what we are really working around is the static-shape requirement (recall the need for a size parameter for nonzero), which is not exclusive to JAX.
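JAX's `jnp.nonzero` accepts a `size` argument (with an optional `fill_value`) so that its output shape is fixed at trace time. The NumPy sketch below mimics that idea with a hypothetical helper, `nonzero_static`, which is not part of NumPy or the array API standard: it truncates or pads the 1-D index array to a caller-supplied length.

```python
import numpy as np

def nonzero_static(x, size, fill_value=0):
    """Hypothetical helper: 1-D nonzero with a caller-fixed output length.

    Indices beyond `size` are dropped; missing slots are padded with
    `fill_value`, so the result always has shape (size,).
    """
    (idx,) = np.nonzero(x)  # x is assumed to be 1-D
    out = np.full(size, fill_value, dtype=idx.dtype)
    n = min(size, idx.shape[0])
    out[:n] = idx[:n]
    return out

mask = np.array([True, False, True, False, True, False])
padded = nonzero_static(mask, size=5, fill_value=-1)
truncated = nonzero_static(mask, size=2)
```

The caller has to pick `size` up front, for example an upper bound on the number of matches; that is the price of a static output shape.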

Now, for something a bit off-topic:

I think the JAX-style functional syntax a = a.at[...].set(...) for in-place operations looks (and arguably works) better than NumPy's, and I'd really like to have it in the array API. Some pros:

  • Looks familiar and simulates the feel of an in-place operation just fine.
  • Makes it clear that nothing is modified. This restricted access pattern would work with any accelerator-backed system. I think it would aid static analysis in systems like Numba as well.
  • More concise, can be chained, and sometimes expresses our intention better.
a = zeros(m)       # initializing a
a[I] += arange(n)  # semantically, still initializing a

# VS

# being concise here is not the important point
# this line becomes a "semantic block" for initialization
a = zeros(m).at[I].add(arange(n))  # initializing a
  • Can specify the indexing mode (more) easily.
# I think these are fairly cumbersome to represent in `numpy`, as we don't have kwargs for __getitem__
b = a.at[I].add(val, unique_indices=True)     # important info for accelerators
c = b.at[J].get(mode='fill', fill_value=nan)  # sure, we have `take`, but this is uniform and cool
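To make the `x = x.at[idx].op(val)` pattern concrete outside JAX, here is a minimal NumPy sketch of a functional, copy-on-update indexer. The names `at`, `_AtIndexer` and `_AtRef` are hypothetical, only `set`, `add` and `get` are covered, and JAX keywords such as `unique_indices` are not modeled. `get(mode="fill", ...)` is approximated for 1-D integer indices only.

```python
import numpy as np

class _AtRef:
    """Hypothetical: a pending functional update of `arr` at `idx`."""

    def __init__(self, arr, idx):
        self.arr = arr
        self.idx = idx

    def set(self, val):
        out = self.arr.copy()  # never touch the input array
        out[self.idx] = val
        return out

    def add(self, val):
        out = self.arr.copy()
        # np.add.at accumulates repeated indices instead of collapsing them
        np.add.at(out, self.idx, val)
        return out

    def get(self, mode=None, fill_value=None):
        if mode != "fill":
            return self.arr[self.idx]
        # 1-D integer-index sketch of mode="fill": out of range -> fill_value
        idx = np.asarray(self.idx)
        n = self.arr.shape[0]
        valid = (idx >= -n) & (idx < n)
        safe = np.where(valid, idx, 0)  # any in-range index works as a placeholder
        return np.where(valid, self.arr[safe], fill_value)

class _AtIndexer:
    def __init__(self, arr):
        self.arr = arr

    def __getitem__(self, idx):
        return _AtRef(self.arr, idx)

def at(arr):
    """Hypothetical stand-in for JAX's `arr.at` property."""
    return _AtIndexer(arr)

a0 = np.zeros(5)
a = at(a0)[[0, 2, 4]].add(np.arange(3))  # a0 itself is left unchanged
b = at(a)[1].set(7.0)
c = at(b)[[0, 10]].get(mode="fill", fill_value=np.nan)
```

Every call copies the whole array: a JIT can often elide that copy, but an eager library pays for it on each update.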

End of comment

Some of my thoughts

  • The last two lines of code, which annotate getitem/setitem-like operations with information for accelerators, make an argument that hasn't been made before. If that's something we'd want to support, then this is a way to do it. A context manager would be another way, as would a keyword in the style of Numba (e.g., the boundscheck option of @njit).
  • As discussed in gh-24, the syntax x = x.at[...] and the in-place support of NumPy et al. are completely equivalent when you have a JIT, and NumPy's version is more efficient if you don't, as long as you can guarantee that you are not modifying a view. NumPy's syntax is also arguably nicer: more concise and more familiar. So, from that perspective, .at isn't ideal.
  • It seems like we do need better static shape support, though. Dynamic shape support is marked as optional in the standard, so what's the alternative?
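The view caveat in the second point can be seen in plain NumPy. In this sketch, basic slicing returns a view, so an in-place update through it also changes the base array, whereas a functional update (the semantics that `x = x.at[...]` guarantees) allocates a new array and leaves the original untouched.

```python
import numpy as np

x = np.arange(5)
v = x[1:4]   # basic slicing: v is a view into x's buffer
v += 10      # in-place on the view, so x is mutated too

y = np.arange(5)
w = y[1:4] + 10  # functional: `+` allocates a new array, y is untouched
```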

The last point is important. Writing generic code is currently difficult when you need to, e.g., update values with a mask. Doing that only the JAX way seems like a nonstarter, because it's far too inefficient for NumPy et al. The question, though, is whether there's something that would work for JAX, TF and Dask. Dask also struggles to some extent with dynamic shapes, although most of it now works (xref https://github.com/dask/dask/issues/2000 and https://github.com/dask/dask/pull/7393). @jakirkham, any thoughts on whether you need anything more (possibly JAX-like) for dynamic shape support in Dask?
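For the narrow case of a masked update with a scalar or broadcastable value, one shape-static formulation that needs no in-place mutation is `where`. The sketch below writes it against an array namespace `xp` supplied by the caller; `masked_set` is a hypothetical helper name, and it is exercised here with NumPy only, which accepts a Python scalar for `value` (a stricter namespace may require an array).

```python
import numpy as np

def masked_set(xp, x, mask, value):
    """Hypothetical helper: return a copy of x with x[mask] replaced by value.

    Uses only `xp.where`, so the output shape equals x.shape regardless of
    how many entries of `mask` are True; nothing is mutated in place.
    """
    return xp.where(mask, value, x)

x = np.array([1.0, -2.0, 3.0, -4.0])
clipped = masked_set(np, x, x < 0, 0.0)
```

This always allocates a new array, which is exactly the extra cost that eager libraries would rather avoid.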

rgommers, Mar 09 '23

For completeness, let me copy the comparison between the two syntaxes from https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html:

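[Image: table comparing JAX's x.at[idx] syntax with the equivalent NumPy in-place expressions, from the JAX documentation page linked above]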

rgommers, Mar 09 '23

I'm curious whether the JAX developers have thought about what would be needed for Python's syntax to allow the more readable x += x[i] while still having the same functionality. Maybe https://peps.python.org/pep-0637/ (keyword arguments in getitem)? Or would you also need more than that (like a += walrus operator or something, I don't know)?

asmeurer, Mar 09 '23

For what it's worth, x += x[i] does work in JAX, and is compatible with JIT. JAX arrays don't override __iadd__, so Python falls back to essentially x = x + x[i].

Since JAX arrays do not have mutable view semantics, this is not at all problematic.
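The fallback described here is ordinary Python behavior: if a class defines `__add__` but not `__iadd__`, `x += y` is evaluated as `x = x + y`, which rebinds the name instead of mutating the object. A minimal, library-free sketch with a hypothetical `Frozen` type:

```python
class Frozen:
    """Hypothetical value type that never mutates itself:
    it defines __add__ but deliberately not __iadd__."""

    __slots__ = ("value",)

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        other_value = other.value if isinstance(other, Frozen) else other
        return Frozen(self.value + other_value)

x = Frozen(1)
alias = x  # a second reference to the original object
x += 2     # no __iadd__, so Python runs x = x + 2 and rebinds x
```

After the augmented assignment, `alias` still sees the old value, which is why the pattern is harmless when there are no mutable views.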

jakevdp, Jun 21 '23

> I'm curious whether the JAX developers have thought about what would be needed for Python's syntax to allow the more readable x += x[i] while still having the same functionality. Maybe https://peps.python.org/pep-0637/ (keyword arguments in getitem)? Or would you also need more than that (like a += walrus operator or something, I don't know)?

Seems like the tricky case would be x[i] += y. This fails silently with NumPy if x[i] does not create a view. In contrast, x.at[i].add(y) always works.
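One case where `x[i] += y` goes silently wrong in NumPy is fancy indexing with repeated indices: `x[i]` is a copy, and the write-back stores only one result per location, so duplicates collapse. `np.add.at` performs unbuffered accumulation instead; the JAX documentation states that `x.at[i].add(y)` likewise applies every update for repeated indices.

```python
import numpy as np

idx = np.array([0, 0, 1])  # index 0 appears twice

x = np.zeros(3)
x[idx] += 1  # buffered: each location is written once, duplicates collapse

y = np.zeros(3)
np.add.at(y, idx, 1)  # unbuffered: every occurrence of an index contributes
```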

shoyer, Jun 21 '23

> Seems like the tricky case would be x[i] += y. This fails silently with NumPy if x[i] does not create a view. In contrast, x.at[i].add(y) always works.

This seems to be the only unsolved problem for testing JAX arrays inside SciPy over at https://github.com/scipy/scipy/pull/20085. Many of these cases already occur in the small portion of the codebase that has been ported to array API compatibility.

lucascolley, Apr 06 '24

Out of curiosity, I tried to run the existing array API test suite of scikit-learn with JAX, and many of the tests failed because of JAX's in-place assignment limitation, which makes JAX support mostly useless in its current state:

  • https://github.com/scikit-learn/scikit-learn/pull/29647

ogrisel, Aug 09 '24

For anyone following this issue but not my SciPy JAX PR, https://github.com/scipy/scipy/pull/20085#issuecomment-2162425667 is (I think) as far as we got with this.

lucascolley, Aug 09 '24