RFC: add APIs for setting elements via an array of indices (e.g., put, put_along_axis)
Copied and adapted from @kgryte's proposal at gh-177 per https://github.com/data-apis/array-api/issues/177#issuecomment-2883011323
Proposal
Add APIs for setting elements via an array of indices.
Motivation
Currently, the array API specification does not provide a direct means of setting a list of elements along an axis. Such operations are relatively common in NumPy code, via either "fancy indexing" or the `put` APIs.
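For reference, here is a minimal NumPy sketch of the two idioms mentioned above: integer-array ("fancy") index assignment and `numpy.put`. The arrays and values are arbitrary illustrations.

```python
import numpy as np

# "Fancy" index assignment: set elements 1 and 3 of a 1-D array.
a = np.arange(5)
a[np.array([1, 3])] = -1

# The same update expressed with NumPy's put API.
b = np.arange(5)
np.put(b, [1, 3], -1)

# Setting a list of elements along an axis: columns 0 and 2 of every row.
m = np.zeros((2, 3), dtype=int)
m[:, np.array([0, 2])] = 7
```

Neither form is currently expressible with standard-conforming array API code, which is the gap this proposal targets.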
The main argument is that indexing does not currently support providing an array of indices to index into an array. As for the principal objection to supporting fancy-index updates, `array_api_extra.at.set` demonstrates that it is often sufficient to mutate in place where possible and to create a copy with the specified updates otherwise.
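The "mutate where possible, copy otherwise" strategy can be sketched in NumPy as below. `set_at_indices` is a hypothetical helper written only for illustration; it is not the `array_api_extra` API, and it uses NumPy's `writeable` flag as an assumed stand-in for "mutation is possible".

```python
import numpy as np


def set_at_indices(x, indices, values, *, copy=None):
    """Hypothetical helper: apply ``x[indices] = values``.

    copy=True  -> always return an updated copy, leaving ``x`` untouched.
    copy=False -> always update ``x`` in place (raises for read-only arrays).
    copy=None  -> update in place if ``x`` is writeable, otherwise copy.
    """
    if copy is None:
        copy = not x.flags.writeable
    if copy:
        x = x.copy()  # the copy owns its data and is writeable
    x[indices] = values  # integer-array ("fancy") index assignment
    return x
```

Callers that always use the return value get correct results under both behaviors, which is what lets a single API serve libraries with mutable and immutable arrays.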
Background
The following table summarizes library implementations of such APIs:
| op | NumPy | CuPy | JAX.numpy | Torch | TensorFlow | Dask.array | ndonnx |
|---|---|---|---|---|---|---|---|
| setting elements along axis | `put` | `put` | `put` | `scatter_`? | `scatter_nd`? | ? | |
| setting elements over matching 1d slices | `put_along_axis` | `put_along_axis` | `put_along_axis` | `scatter_`? See pytorch/pytorch#120209. | `scatter_nd`? | ? | |
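To make the two rows of the table concrete, the sketch below shows how the NumPy functions behave on a small 2-D array; the shapes and values are arbitrary.

```python
import numpy as np

# numpy.put treats `ind` as indices into the *flattened* array.
x = np.zeros((2, 3), dtype=int)
np.put(x, [0, 4], [1, 2])  # flat 0 -> x[0, 0], flat 4 -> x[1, 1]

# numpy.put_along_axis pairs each 1-D slice of x along `axis` with the
# matching slice of `indices`, which must have the same ndim as x.
y = np.zeros((2, 3), dtype=int)
idx = np.array([[2], [0]])  # one column index for each row
np.put_along_axis(y, idx, 9, axis=1)  # sets y[0, 2] and y[1, 0]
```

Note that `numpy.put` takes no `axis` argument, so a standardized "along axis" variant would need to decide whether to follow NumPy's flattened semantics or add an `axis` parameter.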
Discussion about these functions in gh-177 concluded with https://github.com/data-apis/array-api/issues/177#issuecomment-1514155595, especially:
> The JAX issue is most difficult to resolve (can be done, but a lot of work still to deal with read-only views or similar), but the lack of API uniformity makes this a hard sell in general.
This appears to have been resolved: `jax.numpy.put` and `jax.numpy.put_along_axis` are now implemented; they simply return modified copies rather than mutating the array in place.
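A copy-returning variant, close in spirit to the JAX behavior described above, can be sketched on top of NumPy. `put_along_axis_copy` is a hypothetical name used only for illustration; it is not a proposed signature.

```python
import numpy as np


def put_along_axis_copy(x, indices, values, axis):
    """Hypothetical copy-returning variant of numpy.put_along_axis.

    The input ``x`` is never modified; a new array with the updates
    applied is returned, so read-only inputs are also accepted.
    """
    out = np.array(x, copy=True)  # fresh, writeable buffer
    np.put_along_axis(out, indices, values, axis=axis)
    return out
```

Because the result is always a new array, the same calling convention would work for libraries with immutable arrays. In JAX itself, as far as we are aware, these functions must be called with `inplace=False`; the default `inplace=True` raises an error because JAX arrays are immutable.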