jax.numpy array indexing has different out-of-bounds behavior from NumPy
import jax
import jax.numpy as np
x = np.arange(10)
x = jax.device_put(x)
print(x[[13]])
This prints [3], but it should actually throw an out of bounds error like the original NumPy.
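For comparison, the same fancy index in plain NumPy raises (using onp as an alias for the real NumPy, since np is bound to jax.numpy above):

import numpy as onp
onp.arange(10)[[13]]  # IndexError: index 13 is out of bounds for axis 0 with size 10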
P.S.: why does np.arange return a host array? Is this intended behavior or shouldn't it rather behave like np.array and return a device array?
This comment is stale! Please instead see this page of the Sharp Bits notebook.
We should probably document this behavior difference, i.e. that fancy indexing doesn't do bounds-checking, though I'm not sure how feasible it is to solve. I believe it's hard to efficiently check accelerator (e.g. GPU) error conditions, and moreover XLA doesn't have facilities for it (so I guess we'd have to generate our own error-checking code, like wrapping everything in a Maybe monad, which seems impractical).
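To sketch the Maybe-style idea concretely, here is a toy illustration (hypothetical, not anything JAX implements; safe_index is a made-up helper): instead of raising, the computation threads a validity flag alongside the value, and the caller inspects it back on the host.

import jax.numpy as np

def safe_index(x, i):
    # Hypothetical: return a validity flag plus a clamped lookup, since
    # accelerators can't cheaply interrupt execution to raise an error.
    ok = (0 <= i) & (i < x.shape[0])
    val = x[np.clip(i, 0, x.shape[0] - 1)]
    return ok, val

ok, val = safe_index(np.arange(10), 13)  # ok is False; val is a clamped placeholder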
As you probably noticed, non-fancy indexing like this raises an error:
import jax
import jax.numpy as np
x = np.arange(10)
x = jax.device_put(x)
print(x[13])
~~But fancy indexing, like in your example, and "dynamic" indexing like this don't raise errors and instead wrap (indexing modulo the length of the array):~~ Edit: see this Sharp Bits section on out-of-bounds indexing behavior; JAX now clamps instead.
@jax.jit
def f(x, i):
    return x[i]

f(np.arange(10), 13)  # no error is raised; the index is silently adjusted
There might be room for improvement here, but we probably can't match the original NumPy's out-of-bounds behavior in general: under jit the index is an abstract traced value, so Python never sees a concrete value to bounds-check at trace time. I'm curious what other libraries do. (Any wisdom, @hawkinsp?)
> P.S.: why does np.arange return a host array? Is this intended behavior or shouldn't it rather behave like np.array and return a device array?
We just haven't yet implemented one that generates a device array. In general I think jax.numpy reserves the right to return regular numpy.ndarray values or DeviceArray values, as the latter is meant to be just a "lazy, device-persistent" optimization. But as you've shown here, the story isn't quite so clean: when there are errors like out-of-bounds indexing, whether you get a DeviceArray or an ndarray can affect behavior.
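As a quick illustration of the distinction at the time (in later JAX versions jnp.arange returns a device-backed array, so this is historical):

import jax
import jax.numpy as np

print(type(np.arange(10)))                  # numpy.ndarray (at the time of this thread)
print(type(jax.device_put(np.arange(10))))  # DeviceArray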
You can generate device-backed arange-like arrays using lax.iota:
from jax import lax
import jax.numpy as np

print(type(3 + lax.iota(np.int32, 10)))
For this specific issue, I think we should consider this a provide-documentation-of-different-behavior-under-error-conditions case. Let me know if you disagree, but otherwise let's leave this open to track that documentation enhancement.
Providing some kind of documentation of the differences would be great and would solve the issue for me. Right now it’s hard to know what’s expected behavior and what’s not.
FWIW, I think the monadic Maybe solution wouldn't be that painful to implement, if done carefully. I suspect it's our best bet in the short term.
Notably, even non-fancy indexing doesn't raise a bounds error anymore (it wraps instead). Using the example in https://github.com/google/jax/issues/278#issuecomment-457829164:
import jax
import jax.numpy as np
x = np.arange(10)
x = jax.device_put(x)
print(x[13]) # 3
For reference, TensorFlow returns 0 on out-of-bounds indexing: https://www.tensorflow.org/api_docs/python/tf/gather
This has led to a pretty nasty bug for me. This kind of silent failure seems quite dangerous.
I'm not quite sure I understand how checking would be slower, since it seems that under the hood there's some `returned_ix = min(ix, len(xs) - 1)` calculation happening anyway. Barring that, it would be nice to have a feature analogous to `JAX_DEBUG_NANS` that would run with standard bounds checking enabled.
See also https://github.com/google/jax/issues/1451.
> I'm not quite sure I understand how checking would be slower, since it seems that under the hood there's some `returned_ix = min(ix, len(xs) - 1)` calculation happening anyway
The problem is that pausing execution to check for errors can be expensive, especially on accelerators, because control flow all happens on the CPU.
Even on CPUs bounds checks can slow things down, e.g., note this guidance from Cython: https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html#tuning-indexing-further
That said, I agree that we should absolutely have a debug mode here. And default behavior of returning NaNs for out of bounds indexing results would be safer than silent wrapping.
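For what it's worth, later JAX versions added an explicit opt-in for exactly this via the .at indexing helpers (this API postdates the comment above):

import jax.numpy as np

x = np.arange(10.0)  # float dtype so NaN is representable
print(x.at[13].get(mode='fill', fill_value=np.nan))  # nan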
> The problem is that pausing execution to check for errors can be expensive, especially on accelerators, because control flow all happens on the CPU.
This makes sense to me in principle, but I guess I'm not sure how bounds checking ends up being more expensive than `returned_ix = min(ix, len(xs) - 1)`. Aren't they roughly isomorphic? Both end up branching on `ix`.
> That said, I agree that we should absolutely have a debug mode here. And default behavior of returning NaNs for out of bounds indexing results would be safer than silent wrapping.
Totally agree! Hadn't considered returning NaNs.
> This makes sense to me in principle, but I guess I'm not sure how bounds checking ends up being more expensive than `returned_ix = min(ix, len(xs) - 1)`. Aren't they roughly isomorphic? Both end up branching on `ix`.
The most expensive part -- or maybe just a little tricky to do efficiently -- is raising an error back in Python. TensorFlow does this on the CPU but not the GPU, for example.
I also suspect it may be more efficient for XLA to do silent wrapping since it needs to support negative indices already. Instead of using `returned_ix = min(ix, len(xs) - 1)` and doing a separate bounds check on `returned_ix`, it could just use `returned_ix = ix % len(xs)`.
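To make the two strategies concrete, a plain-Python sketch (illustrative only, not XLA's actual code):

def wrap_index(ix, n):
    # Modulo wrapping: 13 -> 3 for n == 10; negative indices come for free.
    return ix % n

def clamp_index(ix, n):
    # Clamping: 13 -> 9 for n == 10.
    return min(max(ix, 0), n - 1)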
> The most expensive part -- or maybe just a little tricky to do efficiently -- is raising an error back in Python. TensorFlow does this on the CPU but not the GPU, for example.
Ah ok, I see what you mean now.
> I also suspect it may be more efficient for XLA to do silent wrapping since it needs to support negative indices already. Instead of using `returned_ix = min(ix, len(xs) - 1)` and doing a separate bounds check on `returned_ix`, it could just use `returned_ix = ix % len(xs)`.
Yeah this would def make sense. The weird thing is that XLA doesn't seem to do any wrapping. It just returns the last element AFAICT.
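This is easy to confirm; a minimal check against current jax.numpy (which clamps out-of-bounds reads to the valid range rather than wrapping):

import jax.numpy as np

x = np.arange(10)
print(x[13])  # 9: clamped to the last element, not 13 % 10 == 3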
Is there a consensus on this yet? The suggestion of a debug mode & documentation would really solve the issue of stumbling into this without realising it, which is what happened to me! Returning NaN at runtime seems like a good suggestion too, though I admittedly do not know whether it is feasible.
Is it possible to add some notice in the documentation or throw a warning at runtime? I heavily relied on this out-of-bounds error behavior as it exists in NumPy, PyTorch, and Python itself. I was not very careful with indexing and expected that a program would fail fast. Now I'm not sure anymore that every piece of code I've ever written in JAX is correct. 😞
Sorry this has caused problems for you! The out-of-bounds indexing issue is discussed in the documentation here: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing
Do you have suggestions for a better place to surface the issue in the docs?
Regarding a runtime warning: that unfortunately is not possible for the same reason that a runtime error cannot be raised; see the discussion at the above link.
You might find jax.experimental.checkify useful for catching OOB errors.
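A minimal sketch of what that could look like (checkify API from recent JAX versions; f here is just a stand-in example):

import jax.numpy as np
from jax.experimental import checkify

def f(x, i):
    return x[i]

# Enable automatic out-of-bounds index checks on f.
checked_f = checkify.checkify(f, errors=checkify.index_checks)
err, out = checked_f(np.arange(10), 13)
err.throw()  # raises, with a message describing the out-of-bounds access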
> Do you have suggestions for a better place to surface the issue in the docs?
I have read the Sharp Bits page many times, but I have never noticed the OOB section. So maybe it would be better to move it closer to the Random Numbers section (e.g. right after that section)? That is a good place to draw attention, since everybody has some issues with RNGs.
> You might find jax.experimental.checkify useful for catching OOB errors.
It looks useful. I'll definitely try it. Thanks!