error_if edge case
Came across a case where I expected eqx.error_if to throw an error, but it didn't. Minimal reproduction:
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

def f(arr: Float[Array, "n d"]):
    def access(state: int):
        # Raise at runtime if the index is past the end of the array.
        state = eqx.error_if(state, state >= arr.shape[0], "State exceeds array bounds.")
        return arr[state]

    def step(carry, i):
        state = carry
        value = access(state)
        state += 1
        return state, value

    init_state = 0
    final_state, values = jax.lax.scan(step, init_state, None, length=3)
    return values
Then
arr = jnp.array([[7.0, 13.0]])
f(arr)
gives
Array([[ 7., 13.], [ 7., 13.], [ 7., 13.]], dtype=float32),
whereas I would've expected it to throw an error when access is run for the second time.
Compare this to
arr = jnp.array([[7.0, 13.0], [1.0, 2.0]])
f(arr)
which raises equinox._errors._EquinoxRuntimeError: State exceeds array bounds.
Thank you for the report! I could reduce this to a JAX-only MWE, and opened an issue.
Hey there! I wanted to come back to this and offer a little extra insight. I actually don't think it's due to funny callback behaviour, as in the JAX issue linked above. Rather, it's due to an odd mix of compiler optimizations all triggering at the same time.
First of all, if we use Equinox's debug tools, then we can see that the output of error_if actually gets eliminated via dead code elimination:
import equinox as eqx
import jax
import jax.numpy as jnp
def f(arr):
    def access(state: int):
        state = eqx.error_if(state, state >= arr.shape[0], "State exceeds array bounds.")
        state = eqx.debug.store_dce(state)
        return arr[state]

    def step(state, i):
        value = access(state)
        state += 1
        return state, value

    init_state = 0
    final_state, values = jax.lax.scan(step, init_state, None, length=3)
    return values

arr = jnp.array([[7.0, 13.0]])
f(arr)
eqx.debug.inspect_dce()
# Found 1 call to `equinox.debug.store_dce`.
# Entry 0:
# <DCE'd>
i.e. the output state = eqx.error_if(...) is unused, and as such the whole lot gets deleted by the compiler. (The rules for error_if are that it won't prevent dead code elimination.) This detail is why we always have to tell people to use the output of error_if, otherwise it will definitely get eliminated as dead code.
Now when we look at this, it might seem a bit surprising that this can be DCE'd. After all, the state = eqx.error_if(...) is used, to index into arr[state]. What gives?
I believe what happens is that JAX/XLA actually unconditionally unrolls the first few steps of the scan and compiles them separately. (I think this is a detail to make implementing some of their own edge cases easier.) That number of steps is longer than arr.shape[0], and what happens is that the compiler realises that the last arr[state] access is definitely out-of-bounds... so it can clip that access to the last value of the array (normal JAX behaviour for OOB indexing)... so it knows that it doesn't need to know the value of the input state computed in state = eqx.error_if(...)... and so that state can now be DCE'd 🤦 . And the error never triggers.
By contrast, if you set arr = jnp.arange(8) and length=9 then the error will trigger, because that's longer than the first steps of that compiler unrolling.
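The clamping behaviour for out-of-bounds reads mentioned above can be seen in isolation. This is a small sketch using the single-row array from the report. The index value 2 is just an arbitrary out-of-bounds choice.

```python
import jax.numpy as jnp

arr = jnp.array([[7.0, 13.0]])  # a single row, as in the report

# Index 2 is out of bounds along axis 0. JAX does not raise here:
# for reads, the index is clamped, so this returns the last (and only) row.
row = arr[jnp.asarray(2)]
```

This is why the reproduction returns the same row three times rather than failing.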
If you want a fix for specifically this case, then simply pass on state = eqx.error_if(...) out of the body function as the loop carry. (Right now the loop carry doesn't pass through the error_if at all, as the error_if is only used inside access.) If you do so then your code will raise at runtime as desired :)
A fix for the general case is probably harder... off the top of my head I'm not sure how, as this is solidly in the land of weird-JAX-internals.
I hope that helps!
Haha that is a funny combination of things to run into! Thanks for the in-depth explanation, Patrick.
What do you think about switching to io_callback instead, as suggested by Jake in the issue above? Would the runtime error machinery become more robust to situations like these or could things still be DCE'd just the same?
So io_callback would block DCE entirely, which would frequently be undesirable:
# loads of code
diffeqsolve(...) # output unused
# loads more code
^ using io_callback would mean that the diffeqsolve has to be performed even though its output is unused, just to check whether it would raise an error.
Makes sense!