jax.numpy: put_along_axis
Hi :) I am using `jax.numpy` to define a NumPyro model, and for that it would be very useful to have a function that does essentially the same thing as NumPy's `put_along_axis`. For now, I have slightly modified NumPy's code so that it returns a JAX array instead. Is there a plan to add this feature to `jax.numpy` in the future?
We haven't implemented this because the semantics of `np.put_along_axis` are to modify the array in place, and this is not possible in JAX because JAX arrays are immutable. I suspect you could accomplish what you want using some combination of index update operators; see https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at
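A small illustration of the difference (the arrays are illustrative): NumPy's `put_along_axis` mutates its argument and returns `None`, a JAX array rejects item assignment with a `TypeError`, and the `.at[...].set(...)` index update returns a new array while leaving the original untouched.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: put_along_axis writes into `a` and returns None
a = np.array([[10, 30, 20], [60, 40, 50]])
idx = np.array([[0], [2]])  # one column index per row
result = np.put_along_axis(a, idx, 0, axis=1)

# JAX: arrays are immutable, so item assignment raises a TypeError ...
b = jnp.array([[10, 30, 20], [60, 40, 50]])
try:
    b[0, 0] = 0
    item_assignment_failed = False
except TypeError:
    item_assignment_failed = True

# ... and the .at[...].set(...) update returns a new array instead.
# `rows` has shape (2, 1) and pairs each row with its column index.
rows = jnp.arange(b.shape[0])[:, None]
b_new = b.at[rows, jnp.asarray(idx)].set(0)
```

After this runs, `a` and `b_new` both hold `[[0, 30, 20], [60, 40, 0]]`, while `b` still holds the original values.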
I see, thanks a lot for the explanation :) I was able to implement a custom version of `np.put_along_axis` using the index update operators, as you suggested. Thanks again!
Since there is some interest in the custom version I am using for my problem, I am sharing the code here (happy to hear any suggestions on how to solve this better):
```python
"""
Defining put_along_axis() from numpy for jax.numpy.
Essentially copied the code from
https://github.com/numpy/numpy/blob/4adc87dff15a247e417d50f10cc4def8e1c17a03/numpy/lib/shape_base.py#L29
"""
import numpy as np
import jax.numpy as jnp


def _make_along_axis_idx(arr_shape, indices, axis):
    # validate `indices` against the destination shape
    if not np.issubdtype(indices.dtype, np.integer):
        raise IndexError("`indices` must be an integer array")
    if len(arr_shape) != indices.ndim:
        raise ValueError("`indices` and `arr` must have the same number of dimensions")
    shape_ones = (1,) * indices.ndim
    dest_dims = list(range(axis)) + [None] + list(range(axis + 1, indices.ndim))

    # build a fancy index, consisting of orthogonal aranges, with the
    # requested index inserted at the right location
    fancy_index = []
    for dim, n in zip(dest_dims, arr_shape):
        if dim is None:
            fancy_index.append(indices)
        else:
            ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
            fancy_index.append(np.arange(n).reshape(ind_shape))
    return tuple(fancy_index)


def custom_put_along_axis(arr, indices, values, axis):
    """
    Functional (not in-place) counterpart of numpy.put_along_axis.

    Parameters
    ----------
    arr : ndarray (Ni..., M, Nk...)
        Destination array.
    indices : ndarray (Ni..., J, Nk...)
        Indices to change along each 1d slice of `arr`. This must match the
        dimension of arr, but dimensions in Ni and Nk may be 1 to broadcast
        against `arr`.
    values : array_like (Ni..., J, Nk...)
        Values to insert at those indices. Its shape and dimension are
        broadcast to match that of `indices`.
    axis : int or None
        The axis to take 1d slices along. If axis is None, `arr` is treated
        as if it had been flattened to 1d (and `indices` must then be 1d).

    Returns
    -------
    out : jax.Array
        A new array with `values` written at `indices`; `arr` itself is
        left unchanged because JAX arrays are immutable.
    """
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices)

    # normalize inputs
    if axis is None:
        # JAX arrays do not implement .flat, so update a flattened copy
        # and restore the original shape afterwards
        flat = arr.ravel()
        idx = _make_along_axis_idx(flat.shape, indices, 0)
        return flat.at[idx].set(values).reshape(arr.shape)

    # stand-in for numpy's normalize_axis_index
    if not -arr.ndim <= axis < arr.ndim:
        raise ValueError(f"axis {axis} is out of bounds for array of dimension {arr.ndim}")
    axis = axis % arr.ndim

    # use the fancy index
    return arr.at[_make_along_axis_idx(arr.shape, indices, axis)].set(values)
```
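To make the index construction concrete, here is the tuple that `_make_along_axis_idx` should produce for a 3-D array with `axis=1`, written out by hand (the shapes and values are illustrative) and checked against NumPy's in-place `np.put_along_axis` applied to a mutable copy:

```python
import numpy as np
import jax.numpy as jnp

arr = jnp.zeros((2, 4, 3), dtype=jnp.int32)
# one target position along axis 1 for every (i, k) pair: shape (2, 1, 3)
indices = jnp.array([[[0, 1, 3]],
                     [[2, 2, 0]]])
values = jnp.array([[[1, 2, 3]],
                    [[4, 5, 6]]])

# orthogonal aranges on axes 0 and 2, `indices` itself on axis 1;
# all three broadcast to shape (2, 1, 3)
fancy_index = (
    np.arange(2).reshape(2, 1, 1),
    indices,
    np.arange(3).reshape(1, 1, 3),
)
out = arr.at[fancy_index].set(values)

# NumPy reference: same update, done in place on a mutable array
expected = np.zeros((2, 4, 3), dtype=np.int32)
np.put_along_axis(expected, np.asarray(indices), np.asarray(values), axis=1)
```

Each `values[i, 0, k]` lands at `out[i, indices[i, 0, k], k]`, which is exactly the "1d slice along `axis`" behaviour the docstring describes.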
One idea may be to define `jax.numpy.put_along_axis`, but add an extra `inplace` keyword that defaults to `True`, such that the function raises an error with the default value. Users could set `inplace=False` and then get a version of the function that returns the updated array. This would avoid the pitfall of users assuming the function works in place and then being confused about why the array isn't changing. What do you think?
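A minimal sketch of that proposal, assuming a hypothetical signature `put_along_axis(arr, indices, values, axis, inplace=True)`. It is not an existing API; it only handles integer axes (no `axis=None`) and builds the same orthogonal-arange index as the custom version above.

```python
import jax.numpy as jnp


def put_along_axis(arr, indices, values, axis, inplace=True):
    """Hypothetical API sketch: refuses to run unless inplace=False."""
    if inplace:
        # JAX arrays cannot be mutated, so the numpy-compatible default errors
        raise ValueError(
            "JAX arrays are immutable; pass inplace=False to get an updated copy"
        )
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices)
    if indices.ndim != arr.ndim:
        raise ValueError("`indices` and `arr` must have the same number of dimensions")
    if not -arr.ndim <= axis < arr.ndim:
        raise ValueError(f"axis {axis} is out of bounds for array of dimension {arr.ndim}")
    axis = axis % arr.ndim
    # one broadcastable arange per non-target axis, `indices` on the target axis
    fancy_index = []
    for dim, n in enumerate(arr.shape):
        if dim == axis:
            fancy_index.append(indices)
        else:
            shape = [1] * arr.ndim
            shape[dim] = n
            fancy_index.append(jnp.arange(n).reshape(tuple(shape)))
    return arr.at[tuple(fancy_index)].set(values)


a = jnp.array([[10, 30, 20], [60, 40, 50]])
cols = jnp.array([[1], [0]])

try:
    put_along_axis(a, cols, 99, axis=1)  # default inplace=True
    default_raised = False
except ValueError:
    default_raised = True

out = put_along_axis(a, cols, 99, axis=-1, inplace=False)
```

With the default, the call fails loudly instead of silently doing nothing; with `inplace=False`, `out` is `[[10, 99, 20], [99, 40, 50]]` and `a` is unchanged.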
I suspect we could refactor things to share much of the index processing with `take_along_axis`.
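A small illustration of why the index processing can be shared, assuming a 2-D array and `axis=1` (variable names are illustrative): reading through the fancy index reproduces `jnp.take_along_axis`, and `.at[...].set` through the very same index performs the put.

```python
import jax.numpy as jnp

arr = jnp.array([[10, 30, 20], [60, 40, 50]])
# column of each row's minimum, kept 2-D: shape (2, 1)
indices = jnp.argmin(arr, axis=1)[:, None]

# shared index processing: an arange over rows, broadcast against `indices`
rows = jnp.arange(arr.shape[0])[:, None]
fancy_index = (rows, indices)

taken = arr[fancy_index]              # gather, i.e. take_along_axis
reference = jnp.take_along_axis(arr, indices, axis=1)
updated = arr.at[fancy_index].set(0)  # scatter, i.e. put_along_axis
```

Here `taken` and `reference` are both `[[10], [40]]`, and `updated` is `[[0, 30, 20], [60, 0, 50]]`.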
Hi @jakevdp, I would like to contribute to this ticket.
I would break the work down as follows; let me know if this makes sense:
- refactor some helper functions out of `take_along_axis` so they can be shared with `put_along_axis`
- implement `put_along_axis` in JAX (default its `inplace` argument to `True` and raise an error, since the array cannot be changed in place)
Any updates on this?
No, but note that you should be able to use `jnp.ndarray.at[]` to do anything that you might do with `put_along_axis`.
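For instance, the `axis=None` mode of `np.put_along_axis`, which writes through a flattened view, can be expressed with `.at[]` by flattening, updating, and reshaping back. This is a minimal sketch; the arrays and indices are illustrative.

```python
import numpy as np
import jax.numpy as jnp

flat_idx = np.array([1, 5])

# NumPy: axis=None writes through a flattened view of `a_np`, in place
a_np = np.arange(6).reshape(2, 3)
np.put_along_axis(a_np, flat_idx, -1, axis=None)

# JAX: flatten, update with .at[], then restore the original shape
a = jnp.arange(6).reshape(2, 3)
a_new = a.ravel().at[jnp.asarray(flat_idx)].set(-1).reshape(a.shape)
```

Both `a_np` and `a_new` end up as `[[0, -1, 2], [3, 4, -1]]`; the JAX version leaves `a` itself unchanged.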
Assuming I have an implementation that is functionally equivalent to the NumPy version, except that it returns a new array instead of working in place, would you be interested in a PR? I know this is doable with `.at` alone, but sometimes it is convenient to have such a function directly 🥲
Sure, I'd review a `put_along_axis` PR.
I am interested in taking a crack at this if it's something you still want. The work for this PR already seems well scoped, but since it's been almost two years, I figure no one is working on it.
BTW, I am a Googler, so the CLA is already signed; I'm just not on the JAX team, so I'm doing this as volunteer work :)
Sounds good - feel free to take a look!