equinox

error_if edge case

[Open] alex-forster opened this issue 2 months ago • 5 comments

I came across a case where I expected eqx.error_if to throw an error, but it didn't. Minimal reproduction:

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

def f(arr: Float[Array, "n d"]):

    def access(state: int):
        state = eqx.error_if(state, state >= arr.shape[0], "State exceeds array bounds.")
        return arr[state]

    def step(carry, i):
        state = carry
        value = access(state)
        state += 1
        return state, value

    init_state = 0
    final_state, values = jax.lax.scan(step, init_state, None, length=3)
    return values

Then

arr = jnp.array([[7.0, 13.0]])
f(arr)

gives

Array([[ 7., 13.], [ 7., 13.], [ 7., 13.]], dtype=float32),

whereas I would've expected it to throw an error when access is run for the second time.

Compare this to

arr = jnp.array([[7.0, 13.0], [1.0, 2.0]])
f(arr)

which raises equinox._errors._EquinoxRuntimeError: State exceeds array bounds.

alex-forster, Nov 06 '25 14:11

Thank you for the report! I was able to reduce this to a JAX-only MWE and opened an issue.

johannahaffner, Nov 06 '25 17:11

Hey there! I wanted to come back to this and offer a little extra insight. I actually don't think it's due to funny callback behaviour, as in the JAX issue linked above. Rather, it's due to an odd mix of compiler optimizations all triggering at the same time.

First of all, if we use Equinox's debug tools, we can see that the output of error_if actually gets removed by dead code elimination (DCE):

import equinox as eqx
import jax
import jax.numpy as jnp

def f(arr):
    def access(state: int):
        state = eqx.error_if(state, state >= arr.shape[0], "State exceeds array bounds.")
        state = eqx.debug.store_dce(state)
        return arr[state]

    def step(state, i):
        value = access(state)
        state += 1
        return state, value

    init_state = 0
    final_state, values = jax.lax.scan(step, init_state, None, length=3)
    return values

arr = jnp.array([[7.0, 13.0]])
f(arr)

eqx.debug.inspect_dce()
# Found 1 call to `equinox.debug.store_dce`.
# Entry 0:
# <DCE'd>

That is, the output of state = eqx.error_if(...) is treated as unused, so the whole computation gets deleted by the compiler. (By design, error_if does not prevent dead code elimination.) This is why we always have to tell people to use the output of error_if: otherwise it will definitely be eliminated as dead code.

At first glance it might seem a bit surprising that this can be DCE'd. After all, the output of state = eqx.error_if(...) is used, to index into arr[state]. What gives?

I believe what happens is that JAX/XLA unconditionally unrolls the first few steps of the scan and compiles them separately. (I think this is done to make some of their own edge cases easier to implement.) That number of steps is larger than arr.shape[0], so the compiler realises that the last arr[state] access is definitely out-of-bounds... so it can clip that access to the last row of the array (normal JAX behaviour for out-of-bounds indexing)... so it knows that it doesn't need the value of the input state computed in state = eqx.error_if(...)... and so that state can be DCE'd 🤦. And the error never triggers.
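The clipping step in that chain can be seen in isolation. A minimal sketch, relying only on JAX's documented behaviour that out-of-bounds reads are clamped rather than raising; read_row is a hypothetical helper. With the one-row array from the reproduction, every index returns row 0, so the result does not depend on the index at all.

```python
import jax
import jax.numpy as jnp

# Shape (1, 2), as in the original reproduction.
arr = jnp.array([[7.0, 13.0]])


@jax.jit
def read_row(i):
    # With a traced index, an out-of-bounds read is clamped into range,
    # so on a single-row array any non-negative index yields row 0.
    return arr[i]
```

Since read_row(2) gives the same row as read_row(0), a compiler that can see this is free to drop whatever computed the index. In the scan above, that includes the error_if check. Whether XLA actually performs this simplification depends on compiler internals.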

By contrast, if you set arr = jnp.arange(8) and length=9, then the error will trigger, because the loop is then longer than the steps covered by that initial compiler unrolling.


If you want a fix for this specific case, simply pass state = eqx.error_if(...) out of the body function as the loop carry. (Right now the loop carry doesn't pass through the error_if at all, as the error_if is only used inside access.) If you do so, your code will raise at runtime as desired :)

A fix for the general case is probably harder... off the top of my head I'm not sure how, as this is solidly in the land of weird-JAX-internals.

I hope that helps!

patrick-kidger, Dec 05 '25 00:12

Haha that is a funny combination of things to run into! Thanks for the in-depth explanation, Patrick.

What do you think about switching to io_callback instead, as suggested by Jake in the issue above? Would the runtime error machinery become more robust to situations like these, or could things still be DCE'd just the same?

johannahaffner, Dec 05 '25 10:12

So io_callback would block DCE entirely, which would frequently be undesirable:

# loads of code
diffeqsolve(...)  # output unused
# loads more code

Here, using io_callback would mean that the diffeqsolve has to be performed even though its output is unused, just to check whether it would raise an error.
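To illustrate that trade-off, here is a minimal sketch using jax.experimental.io_callback directly (this is not Equinox's error machinery). The callback's output is discarded and g returns only x, yet because io_callback is treated as side-effecting, the computation feeding it still has to run. The names record and g are hypothetical, and jnp.sin(x) * 2 merely stands in for an expensive call such as diffeqsolve.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

calls = []


def record(v):
    # Host-side callback: log the value it received and return a dummy result
    # matching the declared shape/dtype.
    calls.append(float(v))
    return np.zeros((), dtype=np.int32)


@jax.jit
def g(x):
    # Stand-in for an expensive computation whose result is otherwise unused.
    y = jnp.sin(x) * 2
    # io_callback is effectful, so it (and therefore y) is kept even though
    # its output is discarded.
    io_callback(record, jax.ShapeDtypeStruct((), jnp.int32), y)
    return x


g(jnp.array(1.0))
jax.effects_barrier()  # wait for pending host callbacks to finish
```

After the call, calls holds one entry, showing that y was computed even though nothing returned from g depends on it. Equinox's choice to let error_if be DCE'd avoids paying this cost.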

patrick-kidger, Dec 05 '25 13:12

Makes sense!

johannahaffner, Dec 05 '25 17:12