array-api
RFC: assignment via integer array indexing
[EDIT] This issue is about the lack of xp.put or equivalent __setitem__ semantics for integer array indices. Read the comments below.
[original post] The current situation when an index is an array of ints or bools is very messy:
- numpy blindly passes every unexpected object it finds to np.asarray, which makes it accept np.ndarray, lists, and tuples, but also memoryviews and anything with an __array__ interface
- Sparse accepts lists, tuples, or numpy arrays; other sparse arrays don't work
- PyTorch won't accept numpy arrays of unsigned integers
- PyTorch won't accept PyTorch arrays of integers with dtype other than the native int
- JAX won't accept lists or tuples
- dask won't accept tuples
I think the array API standard should define a reasonable common surface.
xref https://github.com/data-apis/array-api-compat/pull/205#discussion_r1861288136
NOTE: the much more complicated case of multiple fancy indices, e.g. a[[0, 1], [1, 2]] is out of scope for this issue.
Just a quick note that take is the standardized/uniform way of doing indexing with an integer array (xref gh-416 and the issue that PR links to).
You're completely right that it's very messy across libraries. It's just very hard, since we can't mandate that libraries stop doing something and break backwards compatibility (or technically we can, but it'll probably just be ignored rather than implemented).
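For reference, a minimal sketch of read-only integer array indexing through take. NumPy stands in here for an arbitrary array API namespace (an assumption for illustration); the axis is passed explicitly because x is 2-D.

```python
import numpy as np

xp = np  # assumption: NumPy used as a stand-in for any array API namespace

x = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
idx = xp.asarray([2, 0])

# Gather rows 2 and 0; equivalent to NumPy's x[[2, 0], :]
rows = xp.take(x, idx, axis=0)
print(rows)  # [[7 8 9]
             #  [1 2 3]]
```

Note that take only covers the read side; there is no standardized counterpart for writing along an integer index.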
I couldn't find put or anything similar though? How do you update along a fancy index?
cross-ref https://github.com/data-apis/array-api/pull/900/ which aims to actually standardize it.
Looks like gh-900 covered this request, and we can close it as completed, I'd think.
@ev-br #900 doesn't cover setitem semantics, however.
Would be great to update the issue title/text then, if we repurpose it for __setitem__ specifically.
Reviving this issue as we are currently facing the limitation in https://github.com/scipy/scipy/pull/23425. Is there any movement on this? The case for __setitem__ semantics with integer arrays seems both sensible and widely implemented. When running the current implementation in https://github.com/scipy/scipy/pull/23425, the only framework that fails to set items is array-api-strict.
Is there any movement on this?
No, no change, sorry.
the only framework that fails to set items is array-api-strict.
I don't think that's possible: JAX doesn't implement __setitem__. I suspect it's either not tested there or the SciPy implementation does an out-of-place update.
Ah right, then maybe this is a topic for array-api-extra instead, because this pattern should be common in most frameworks:

```python
import array_api_strict as xp
import jax.numpy as jnp

# JAX: out-of-place update along an integer array index
x = jnp.array([[1, 2, 3], [4, 5, 6]])
idx = jnp.array([0])
# This could be emulated in array-api-extra with xpx.at(x)[idx, ...].set(xp.asarray([[7, 8, 9]])).
# It is already emulating this case for regular indexing.
new_x = x.at[idx, ...].set(jnp.array([[7, 8, 9]]))
print(new_x)  # [[7 8 9] [4 5 6]]

# array-api-strict: in-place update along the same integer array index
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
idx = xp.asarray([0])
x[idx, ...] = xp.asarray([[7, 8, 9]])  # IndexError: Fancy indexing __setitem__ is not supported.
```
Ah right, then maybe this is a topic for array-api-extra instead.
Indeed: https://data-apis.org/array-api-extra/generated/array_api_extra.at.html
Using the current spec, is there an array API-compatible way to implement this feature though? Because swapping the last line with
```python
import array_api_extra as xpx
...
new_x = xpx.at(x)[idx, ...].set(xp.asarray([[7, 8, 9]]))  # IndexError: ...
```

raises the same error, since array-api-extra tries to translate this into x[idx] = y for array-api-strict.
It is possible to update xpx.at to detect an integer array index and force an out-of-place operation in that case.
That would be great. Should I open an issue in array-api-extra?
Yes, but I'm not sure about my current bandwidth. I'd suggest taking a stab at the PR yourself if you care about having it in quickly.
Alright, will do if I find the time.
I played around with the issue for a bit and encountered the following situation:
```python
import jax.numpy as jnp
import numpy as np

# JAX: out-of-place .at[...].add with a repeated index
x = jnp.asarray([0.0, 1.0])
y = jnp.asarray([2.0, 4.0])
idx = jnp.asarray([1, 1])
x = x.at[idx].add(y)
print(x)  # [0., 7.] on jax 0.6.0

# NumPy: in-place += with the same repeated index
x = np.asarray([0.0, 1.0])
y = np.asarray([2.0, 4.0])
idx = np.asarray([1, 1])
x[idx] += y
print(x)  # [0., 5.] on numpy 2.2.6
```
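The discrepancy comes from buffering. In NumPy, x[idx] += y expands to x[idx] = x[idx] + y: the right-hand side is gathered once, so writes to a repeated index overwrite each other and the last one wins. NumPy's unbuffered np.add.at applies every (index, value) pair in turn and ends up with the same accumulated value as JAX's .at[].add. A minimal sketch:

```python
import numpy as np

idx = np.asarray([1, 1])
y = np.asarray([2.0, 4.0])

# Buffered: equivalent to x[idx] = x[idx] + y, i.e. gather, add, then scatter.
x = np.asarray([0.0, 1.0])
x[idx] += y
buffered = x.copy()  # [0., 5.]: the last write (1.0 + 4.0) wins

# Unbuffered: np.add.at applies each (index, value) pair, so duplicates accumulate.
x = np.asarray([0.0, 1.0])
np.add.at(x, idx, y)
unbuffered = x.copy()  # [0., 7.]: 1.0 + 2.0 + 4.0
```

So even within NumPy, "add y at idx" has two different answers depending on which API is used.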
To me this suggests that fancy __setitem__ semantics cannot be implemented at the moment because there is no agreement over the correct behavior.
@amacati that is correct. xref gh-24 for lots of examples like that and more context. gh-609 is also relevant.
Indeed. The only way this can get implemented is if there are unambiguous semantics for it. In turn, the only way for that to happen is when in-place and out-of-place semantics are identical, i.e. one can guarantee that the += operation modifies only one array (which numpy and cupy currently cannot). This issue will not move until there is infrastructure, at least in numpy, for verifying that (the kind of "view tracking" one would also need for copy-on-write behavior, already discussed in gh-24).
I'm not sure that it is useful to keep this issue open for __setitem__, since it lacks context and is essentially duplicate with the other issues I linked.
Thanks for the thorough explanation and the references. The thought was to add a workaround in array-api-extra, but given the situation that's obviously not possible. One follow-up question: is the behavior of __setitem__ at least consistent when only setting elements of the array, i.e., for updates in the style of x[idx] = y? Then one could fail all in-place operations like operator.iadd, but at least support item assignment in xpx.at.
The thought was to add a workaround in array-api-extra, but given the situation that's obviously not possible
I'm not sure I understand - what is missing from the functional array_api_extra.at? It should always be possible to special-case things there so it works for all array libraries.
That should answer the follow-up question as well: item assignment should work there already I'd think.
Item assignment using xpx.at does not work for integer array indices. All real frameworks succeed, but array-api-strict fails because the standard is not defined for this case. If the result is unique across all frameworks it would make sense to add a workaround for array-api-extra. If not, implementing the workaround seems misleading.
Oh okay, that seems like an annoying but minor practical issue.
If the result is unique across all frameworks it would make sense to add a workaround for array-api-extra.
It seems fine to me here to special-case array-api-strict or to have a generic fallback path for unknown libraries using standard functions like where; I'd think both are possible.
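As an illustration of such a generic fallback, here is a minimal sketch built only from standard functions (arange, where, max, take). It assumes a 1-D x, a 1-D non-negative in-bounds integer idx, and a 1-D y with one value per index; the helper name set_via_where is hypothetical, and NumPy stands in for an arbitrary array API namespace. Duplicates resolve to the last occurrence, mimicking NumPy on CPU. Note the (n, k) intermediate mask, whose size is the product of the array length and the index length.

```python
import numpy as np


def set_via_where(xp, x, idx, y):
    """Out-of-place x[idx] = y (hypothetical helper, standard functions only).

    Assumes 1-D x, 1-D non-negative in-bounds integer idx, and 1-D y with
    y.shape[0] == idx.shape[0]. Duplicate indices take the last value.
    """
    n, k = x.shape[0], idx.shape[0]
    if k == 0:
        return x
    # (n, k) boolean mask: hit[i, j] is True when idx[j] == i.
    hit = xp.arange(n)[:, None] == idx[None, :]
    # For each target position i, the largest j with idx[j] == i, or -1 if none.
    last = xp.max(xp.where(hit, xp.arange(k)[None, :], -1), axis=1)
    written = last >= 0
    # Clamp to a valid position before gathering; unwritten rows are discarded below.
    vals = xp.take(y, xp.where(written, last, 0))
    return xp.where(written, vals, x)


x = np.asarray([0.0, 1.0, 2.0, 3.0])
idx = np.asarray([1, 1, 3])
y = np.asarray([10.0, 20.0, 30.0])
out = set_via_where(np, x, idx, y)  # [0., 20., 2., 30.]; x itself is unchanged
```

This only covers the simplest case (first axis, 1-D index); multi-dimensional values or indices would need further reshaping and masking.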
All real frameworks succeed, but array-api-strict fails because the standard is not defined for this case.
If indeed all real frameworks succeed, then it very well could be added to the standard, I suppose. And then array-api-strict happily implements it.
What probably happens is that all real frameworks succeed with some additional limitations (at a guess: no repeated indices?).
I'd suggest starting by writing a hypothesis test as a PR to array-api-tests. This has proven very helpful in fishing out edge cases like these.
then it very well could be added to the standard, I suppose
It cannot. "real frameworks succeed" for a custom at function because the implementation is different per framework under the hood. And proposing a new functional at for the standard, something exactly zero libraries currently have, is a nonstarter.
If the result is unique across all frameworks it would make sense to add a workaround for array-api-extra. If not, implementing the workaround seems misleading.
After some more digging, this is what currently happens across frameworks:
```python
# Behavior for indexing differs between devices:
# CPU: For duplicate indices, the **last** value is used
# GPU: For duplicate indices, the **first** value is used
import cupy as cp
import jax
import jax.numpy as jnp
import numpy as np
import torch

# Numpy: CPU, last value is used
x, y = np.array([0.0, 1.0]), np.array([2.0, 3.0])
idx = np.array([1, 1])
x[idx] = y
print(x)  # [0, 3]

# Cupy: GPU, first value is used
x, y = cp.array([0.0, 1.0]), cp.array([2.0, 3.0])
x[cp.array(idx)] = y
print(x)  # [0, 2]

# Torch: CPU, last value is used
x, y = torch.tensor([0.0, 1.0]), torch.tensor([2.0, 3.0])
x[torch.tensor(idx)] = y
print(x)  # [0, 3]

# Torch: GPU, first value is used
x, y = torch.tensor([0.0, 1.0]).cuda(), torch.tensor([2.0, 3.0]).cuda()
x[torch.tensor(idx).cuda()] = y
print(x)  # [0, 2]

# JAX: CPU, last value is used
device = jax.devices("cpu")[0]
x, y = jnp.array([0.0, 1.0], device=device), jnp.array([2.0, 3.0], device=device)
x = x.at[jnp.array(idx, device=device)].set(y)
print(x)  # [0, 3]

# JAX: GPU, first value is used
x, y = jnp.array([0.0, 1.0]), jnp.array([2.0, 3.0])
x = x.at[jnp.array(idx)].set(y)
print(x)  # [0, 2]
```
The results for duplicate indices differ depending on the device. What's the overall sentiment on this? One could still implement special casing for array-api-strict in array-api-extra mimicking the numpy behavior, but it feels slightly misleading.
Alternatively, since array-api-strict is eager, we could error out on non-unique indices because of undefined behavior.
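A minimal sketch of such an eager check, assuming a 1-D integer index array. The helper name check_unique_indices is hypothetical (it is not part of array-api-strict), and NumPy stands in for the namespace; only the standard sort and any functions are used.

```python
import numpy as np


def check_unique_indices(xp, idx):
    """Raise if a 1-D integer index array contains repeated entries (hypothetical helper)."""
    if idx.shape[0] < 2:
        return
    s = xp.sort(idx)
    # After sorting, any duplicate sits next to an equal neighbour.
    if bool(xp.any(s[1:] == s[:-1])):
        raise IndexError("duplicate indices in integer-array __setitem__ are not supported")


check_unique_indices(np, np.asarray([2, 0, 1]))  # passes silently
```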
Duplicate indices have valid use cases: if some calculation yields (index, value) pairs, you can end up with duplicates, and that is fine as long as the value is the same for each duplicate index. Hence, forbidding them doesn't sound right. Just picking a consistent choice in the at implementation for array-api-strict does.
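Following that reasoning, an eager check could accept duplicates as long as they agree. A minimal sketch, assuming 1-D idx and y of equal length; duplicates_are_consistent is a hypothetical helper, and NumPy stands in for the namespace:

```python
import numpy as np


def duplicates_are_consistent(xp, idx, y):
    """True if every repeated index in idx is paired with the same value in y (hypothetical helper).

    Assumes 1-D idx and 1-D y of equal length.
    """
    order = xp.argsort(idx)
    s_idx = xp.take(idx, order)
    s_y = xp.take(y, order)
    # Equal indices are contiguous after sorting, so comparing neighbours is enough.
    same_idx = s_idx[1:] == s_idx[:-1]
    diff_val = s_y[1:] != s_y[:-1]
    return not bool(xp.any(xp.logical_and(same_idx, diff_val)))


ok = duplicates_are_consistent(np, np.asarray([1, 3, 1]), np.asarray([5.0, 6.0, 5.0]))   # True
bad = duplicates_are_consistent(np, np.asarray([1, 3, 1]), np.asarray([5.0, 6.0, 7.0]))  # False
```

When this returns True, "first wins" and "last wins" produce the same result, so the device-dependent ordering no longer matters.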
This was/is a common question for scatter-type functions. On GPU/accelerator devices, and even on CPU if the operation were implemented with parallelism under the hood, duplicate indices inherently cause data races.
Summary of https://github.com/data-apis/array-api-extra/pull/395:
We discussed how to include this into array-api-extra and came to the conclusion that it is currently not desirable to implement a workaround for it.
The main reason is that assignments would have to be replaced by boolean masks, which are not as expressive as integers and thus require multiple workarounds. Even focusing just on __setitem__ over the first axis and 1D integer arrays we hit several edge cases that would have required element shuffling and intermediate masks with quadratic size in the index length.
See e.g. https://github.com/data-apis/array-api-extra/pull/395#issuecomment-3234160232.
We also discussed this in the consortium community meeting yesterday and the sentiment was that it is fine to skip for array-api-strict in SciPy for now.
A potential long-term solution would be to standardise (some subset of) fancy indexing (or potentially multiple different behaviours if they are competing for standardisation??) as an extension (/extensions) to the standard. Then perhaps array-api-extra could implement somewhat more generic code paths, rather than per-library workarounds. SciPy would in turn be able to better specify which kinds of libraries it supports.
That said, such an effort seems very low priority right now. We should keep our eyes out for any other situations where we need to skip for array-api-strict and reassess.
A potential long-term solution would be to standardise (some subset of) fancy indexing (or potentially multiple different behaviours if they are competing for standardisation??) as an extension (/extensions) to the standard.
We only briefly touched on that, but let me just say that I think that that is unlikely. Fancy indexing is very complex, unintuitive, and hard to write down the rules for. An extension that doesn't match what libraries can implement is low value; having a function that does some of what we want would be much easier to deal with.
We should keep our eyes out for any other situations where we need to skip for array-api-strict and reassess
In SciPy, it'd be great to keep an explicit list of places where we skip array-api-strict. Some skips are technical IIRC, and some are fundamental, like this one. An issue with an overview would be helpful, I think.