
jax.numpy: put_along_axis

Open · LaraFuhrmann opened this issue 2 years ago · 11 comments

Hi :), I am using jax.numpy to define a NumPyro model. For this it would be very useful to have a function that does essentially the same as NumPy's put_along_axis. For now, I have modified NumPy's code slightly so that it outputs a JAX array. But I was wondering whether there are plans to add this feature to jax.numpy in the future?

LaraFuhrmann · Mar 18 '22

We haven't implemented this because the semantics of np.put_along_axis are to modify the array in-place, and this is not possible in JAX because JAX arrays are immutable. I suspect you could accomplish what you want to do using some combination of index update operators; see https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at
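For example, here is a minimal sketch (assuming a 2-D array and one index per row) of how np.put_along_axis(arr, cols, 99, axis=1) maps onto the index update operators: the row positions come from a broadcast jnp.arange, and .at[...].set() returns a new array instead of modifying arr.

```python
import jax.numpy as jnp

# Destination array and, for each row, the column to overwrite.
arr = jnp.array([[10, 30, 20],
                 [60, 40, 50]])
cols = jnp.array([[1], [0]])               # shape (2, 1): integer indices along axis 1

# Pair each entry of `cols` with its row number via a broadcast arange.
rows = jnp.arange(arr.shape[0])[:, None]   # shape (2, 1): [[0], [1]]
result = arr.at[rows, cols].set(99)        # new array; `arr` itself is unchanged
# result -> [[10, 99, 20], [99, 40, 50]]
```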

jakevdp · Mar 18 '22

I see - thanks a lot for the explanation :) I was able to implement a custom version of np.put_along_axis using the index update operators as you suggested. Thanks again!

LaraFuhrmann · Mar 21 '22

Since there has been some interest in the custom version I am using for my problem, I am sharing the code here (happy to hear any suggestions on how to solve this better):

```python
"""
Defining put_along_axis() from numpy for jax.numpy.
Essentially copied the code from
https://github.com/numpy/numpy/blob/4adc87dff15a247e417d50f10cc4def8e1c17a03/numpy/lib/shape_base.py#L29
"""

import numpy as np
import jax.numpy as jnp


def _make_along_axis_idx(arr_shape, indices, axis):
    # validate inputs, then compute the dimensions to iterate over
    if not np.issubdtype(indices.dtype, np.integer):
        raise IndexError("`indices` must be an integer array")
    if len(arr_shape) != indices.ndim:
        raise ValueError("`indices` and `arr` must have the same number of dimensions")
    shape_ones = (1,) * indices.ndim
    dest_dims = list(range(axis)) + [None] + list(range(axis + 1, indices.ndim))

    # build a fancy index, consisting of orthogonal aranges, with the
    # requested index inserted at the right location
    fancy_index = []
    for dim, n in zip(dest_dims, arr_shape):
        if dim is None:
            fancy_index.append(indices)
        else:
            ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1:]
            fancy_index.append(np.arange(n).reshape(ind_shape))

    return tuple(fancy_index)


def custom_put_along_axis(arr, indices, values, axis):
    """
    Out-of-place counterpart of np.put_along_axis: returns an updated copy
    of `arr` instead of modifying it.

    Parameters
    ----------
    arr : ndarray (Ni..., M, Nk...)
        Destination array.
    indices : ndarray (Ni..., J, Nk...)
        Indices to change along each 1d slice of `arr`. This must match the
        dimension of arr, but dimensions in Ni and Nk may be 1 to broadcast
        against `arr`.
    values : array_like (Ni..., J, Nk...)
        Values to insert at those indices. Its shape and dimension are
        broadcast to match that of `indices`.
    axis : int or None
        The axis to take 1d slices along. If axis is None, the destination
        array is treated as if it had been flattened to 1d; the result
        keeps the original shape.
    """
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices)

    if axis is None:
        # JAX arrays do not implement `.flat`, so update a flattened copy
        # and reshape it back at the end
        if indices.ndim != 1:
            raise ValueError("when axis=None, `indices` must be one-dimensional")
        flat = arr.ravel()
        idx = _make_along_axis_idx(flat.shape, indices, 0)
        return flat.at[idx].set(values).reshape(arr.shape)

    # normalize negative axes
    if not -arr.ndim <= axis < arr.ndim:
        raise ValueError(f"axis {axis} is out of bounds for array of dimension {arr.ndim}")
    axis = axis % arr.ndim

    # use the fancy index; .at[...].set() returns a new array
    return arr.at[_make_along_axis_idx(arr.shape, indices, axis)].set(values)
```

LaraFuhrmann · Apr 19 '22

One idea may be to define jax.numpy.put_along_axis, but add an extra inplace keyword that defaults to True, such that the function errors with the default value. Users could set inplace=False and then have a version of the function that returns the updated array. It would avoid the potential pitfall of users assuming the function works in-place and then being confused why the array isn't changing. What do you think?
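A minimal sketch of what that keyword could look like, using a hypothetical function name put_along_axis_sketch (not a JAX API) and handling only an integer axis: with the default inplace=True it raises, and with inplace=False it returns an updated copy built with .at[].set().

```python
import jax.numpy as jnp
import numpy as np


def put_along_axis_sketch(arr, indices, values, axis, inplace=True):
    """Hypothetical sketch of the proposed `inplace` keyword (integer axis only)."""
    if inplace:
        # JAX arrays are immutable, so the default value refuses to run
        # instead of returning a copy the caller might silently ignore.
        raise ValueError("JAX arrays are immutable; pass inplace=False "
                         "to receive an updated copy instead.")
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices)
    axis = axis % arr.ndim  # map negative axes to positive ones (no bounds check here)
    # One index array per dimension: `indices` along `axis`, and an arange
    # shaped to broadcast along its own dimension everywhere else.
    index = tuple(
        indices if d == axis
        else np.arange(n).reshape([-1 if k == d else 1 for k in range(arr.ndim)])
        for d, n in enumerate(arr.shape)
    )
    return arr.at[index].set(values)


x = jnp.zeros((2, 3))
y = put_along_axis_sketch(x, jnp.array([[2], [0]]), 1.0, axis=-1, inplace=False)
# y -> [[0., 0., 1.], [1., 0., 0.]]; x is unchanged
```

Making the error the default means code ported from NumPy fails loudly rather than appearing to work while discarding the update.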

I suspect we could refactor things to share much of the index processing with take_along_axis.
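A sketch of that idea, with a hypothetical helper along_axis_index (not an existing JAX function): one index tuple, built from np.indices(..., sparse=True) with the indices substituted at the chosen axis, serves both the read that take_along_axis performs and the write that put_along_axis would perform.

```python
import jax.numpy as jnp
import numpy as np


def along_axis_index(shape, indices, axis):
    # Hypothetical shared helper: broadcastable aranges for every dimension
    # (np.indices with sparse=True), with `indices` swapped in at `axis`.
    index = list(np.indices(shape, sparse=True))
    index[axis] = indices
    return tuple(index)


arr = jnp.array([[4, 1, 3],
                 [2, 5, 0]])
idx = jnp.argmin(arr, axis=1)[:, None]   # column of each row minimum, shape (2, 1)
index = along_axis_index(arr.shape, idx, axis=1)

taken = arr[index]                 # "take" side: matches jnp.take_along_axis(arr, idx, axis=1)
updated = arr.at[index].set(-1)    # "put" side: reuses the identical index
# taken -> [[1], [0]], updated -> [[4, -1, 3], [2, 5, -1]]
```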

jakevdp · Apr 20 '22

Hi @jakevdp, I would like to contribute to this ticket.

I would break this ticket down as follows; let me know if it makes sense:

  1. refactor some helper functions from take_along_axis that can be shared with put_along_axis
  2. implement put_along_axis in JAX
  3. default its inplace argument to True, raising an error in that case since the array cannot be changed in place

riven314 · May 28 '22

Any update regarding this?

jdeschena · Jul 20 '23

No, but note that you should be able to use jnp.ndarray.at[] to do anything that you might do with put_along_axis.
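For instance, the axis=None case of put_along_axis, which treats the array as if it were flattened, can be written as an update on a flattened copy that is then reshaped back (a sketch, assuming one-dimensional flat indices):

```python
import jax.numpy as jnp

arr = jnp.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
flat_idx = jnp.array([0, 4])        # positions in the flattened array

# Counterpart of np.put_along_axis(arr, flat_idx, -1, axis=None):
# update a flattened copy, then restore the original shape.
result = arr.ravel().at[flat_idx].set(-1).reshape(arr.shape)
# result -> [[-1, 1, 2], [3, -1, 5]]
```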

jakevdp · Jul 20 '23

Assuming I have an implementation functionally equivalent to the NumPy version, except that it returns a new array instead of working in place, would you be interested in me making a PR? I know it is doable with just .at[], but sometimes it is convenient to have such a function directly 🥲

jdeschena · Jul 20 '23

Sure, I'd review a put_along_axis PR

jakevdp · Jul 20 '23

I am interested in taking a crack at this if it is something you'd still want. The work for this PR already seems well-scoped, and given it's been almost 2 years, I figure no one is working on this.

Btw, I am a Googler, so the CLA is already signed; I am just not on a JAX team, so I'm doing this as volunteer work :)

DCtheTall · Apr 23 '24

Sounds good - feel free to take a look!

jakevdp · Apr 23 '24