Bug: state invalidation does not work across JIT boundaries

Open matthieumeo opened this issue 5 months ago • 4 comments

Hello @patrick-kidger !

Thanks for the amazing work with Equinox, I'm a huge fan and user 👍🏻

I would like to report a bug with state invalidation across JIT boundaries: if I jit-compile a function that calls equinox.nn.State.set() on a state object, the state object is not marked as invalid/old and can therefore be reused multiple times.

Looking at the source code, I believe this is because JAX's tracing simply drops the mutation step, self._state = _sentinel, which prevents Equinox from flagging this state as already consumed and hence invalid.

Here is an MRE, using the example of the stateful API:

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array
import jax

class Counter(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)
   
    @jax.jit
    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        value = state.get(self.index)
        new_x = x + value
        new_state = state.set(self.index, value + 1)
        return new_x, new_state

counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

_, new_state = counter(x, state)
print(state) # State(0x77dd10d000d0=weak_i32[])
print(state._state) # {0: Array(0, dtype=int32, weak_type=True)}

If I remove the @jax.jit decorator in the example above, I get the intended behavior:

print(state) # State(~old~)
print(state._state) # _Sentinel()
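The same difference can be reproduced without Equinox. The following is a minimal sketch using only JAX's pytree registration; `Box` and `use` are hypothetical stand-ins, not Equinox code. It assumes the relevant mechanism is that `jax.jit` flattens its arguments and rebuilds a fresh object around the tracers, so a Python-side mutation inside the jitted function lands on that temporary copy rather than on the caller's object.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Box:
    """Toy stand-in for a state object (hypothetical, not Equinox code)."""

    def __init__(self, value):
        self.value = value
        self.consumed = False  # plain Python attribute, not a pytree leaf

    def tree_flatten(self):
        # Only `value` is a leaf; `consumed` is not carried across the boundary.
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        # Rebuilds a fresh, unconsumed Box around the given leaf.
        return cls(children[0])


def use(box):
    box.consumed = True  # Python-side mutation, analogous to marking a state as old
    return box.value + 1


eager_box = Box(jnp.array(0))
use(eager_box)  # mutates the caller's object directly

jitted_box = Box(jnp.array(0))
result = jax.jit(use)(jitted_box)  # mutates only the copy rebuilt during tracing
```

After running this, `eager_box.consumed` is `True` while `jitted_box.consumed` is still `False`, which mirrors the two printouts above.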

I am not sure whether this can be fixed, since side effects like this break the pure-function paradigm assumed by JAX, and more specifically by higher-order transforms like jit. Does this mean we should not jit-compile functions involving state objects? Or can we do so, at the cost of losing Equinox's built-in safeguard against reusing a previous state? I tend to think the latter, but just want to double-check.

Looking forward, I think it might be beneficial to get rid of this safeguard, to make sure the behavior is consistent with/without jit. Happy to draft a PR if you agree!

Thanks a lot for the help on this.

matthieumeo avatar Aug 13 '25 08:08 matthieumeo

FYI, here are the versions of JAX/Equinox used for the MRE above:

>>> jax.__version__
'0.5.3'
>>> eqx.__version__
'0.11.12'

matthieumeo avatar Aug 13 '25 08:08 matthieumeo

So the behavior here is actually intentional, in that:

  1. the safeguard is important to catch a common class of error: forgetting to use the output from .set.

  2. flattening+unflattening offers a way to clone the state if really desired.

That said, maybe we could do something a bit better here: put the flag in a box, and put that box in the static metadata. Then using .set would propagate even across flattening+unflattening boundaries (such as JIT). We could then add an explicit .clone() method to handle the case in which cloning is desired.

WDYT?

patrick-kidger avatar Aug 14 '25 07:08 patrick-kidger

Hi again!

Thanks for the quick reply. I think I understand your point, and I like the idea. Just to double-check: is this roughly what you had in mind?

(⚠️ just a sketch to convey the gist — not runnable as-is)

def _raise_state_error():
    raise ValueError(_state_error)

class Flag:
    def __init__(self, value):
        self.flag = value

class State(Module):
    _state: _Sentinel | dict[object | int, Any]
    _consumed: Flag = field(static=True)  # mutable box in static metadata

    def __init__(self, model):
        state = {}
        leaves = jtu.tree_leaves(model, is_leaf=_is_index)
        for leaf in leaves:
            if _is_index(leaf):
                if isinstance(getattr(leaf, "init", _Sentinel()), _Sentinel):
                    raise ValueError(
                        "Do not call `eqx.nn.State(model)` directly. "
                        "Use `eqx.nn.make_with_state(ModelClass)(...)` instead."
                    )
                state[leaf.marker] = jtu.tree_map(jnp.asarray, leaf.init)
        self._state = state
        self._consumed = Flag(False)  # every fresh state starts unconsumed

    def set(self, item: StateIndex[_Value], value: _Value) -> "State":
        if self._consumed.flag:
            # use callback so the error is raised at *runtime* under JIT
            jax.debug.callback(_raise_state_error)
        ...  # build new state as before
        new_self = object.__new__(State)
        new_self._state = state
        new_self._consumed = Flag(False)  # the returned state gets its own box
        self._consumed.flag = True  # flip flag on this instance
        return new_self

    def get(self, item: StateIndex[_Value]) -> _Value:
        if self._consumed.flag:
            # likewise, enforce at runtime under JIT
            jax.debug.callback(_raise_state_error)
        ...
        return self._state[item.marker]

Do we agree this would work because the mutation happens on a static field, so JAX's tracer will not drop it? How safe/fragile is that given JAX's FP paradigm? Also, I think I need jax.debug.callback for full JIT compatibility, no?
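To make the "box in static metadata" idea concrete, here is a small runnable sketch using only `jax.tree_util`. `Flag`, `ToyState`, and `clone` are hypothetical names for illustration and do not reflect Equinox's implementation. It assumes the box is returned as the pytree's auxiliary data, so the object rebuilt by unflattening holds a reference to the very same box.

```python
import jax.numpy as jnp
from jax import tree_util


class Flag:
    """Hypothetical mutable box; hashed and compared by identity."""

    def __init__(self, value=False):
        self.value = value


@tree_util.register_pytree_node_class
class ToyState:
    def __init__(self, value, consumed=None):
        self.value = value
        self.consumed = Flag() if consumed is None else consumed

    def tree_flatten(self):
        # The box travels as auxiliary (static) data, not as a leaf.
        return (self.value,), self.consumed

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)

    def clone(self):
        # Explicit opt-in copy with its own, fresh box.
        return ToyState(self.value)


state = ToyState(jnp.array(0))
leaves, treedef = tree_util.tree_flatten(state)
rebuilt = tree_util.tree_unflatten(treedef, leaves)

cloned = state.clone()
rebuilt.consumed.value = True  # flip the flag on the rebuilt copy only
```

In this sketch, `rebuilt` is a different object from `state`, yet flipping its flag also marks `state` as consumed, whereas `cloned` stays untouched. Whether the same holds reliably under every JAX transform is not something this sketch establishes.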

If I've understood correctly, I'd be happy to draft a proper diff along these lines. Does this capture what you had in mind?

matthieumeo avatar Aug 20 '25 16:08 matthieumeo

I've been noodling on this idea, and I've realised a potential problem with it.

Namely, that there is a lot of JAX code out there that basically relies on the fact that calling f(x) twice will give you the same answer each time. For example, recomputing from checkpoints.

We'd be breaking that invariant here: by introducing this Flag, we'd get some pretty spooky action-at-a-distance, where the internals of f (using state.set) mutate x (the state) and thus prevent you from calling f(x) again!

This would almost certainly break a lot of JAX code, mostly in difficult-to-reason-about edge cases.
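The concern can be sketched in a few lines. This is an illustrative toy under stated assumptions, not Equinox code: `Flag`, `ToyState`, and `roundtrip` are hypothetical, and `roundtrip` merely stands in for whatever flattening+unflattening a JAX transform performs on its inputs.

```python
import jax.numpy as jnp
from jax import tree_util


class Flag:
    """Hypothetical mutable box shared through the static metadata."""

    def __init__(self):
        self.value = False


@tree_util.register_pytree_node_class
class ToyState:
    def __init__(self, value, consumed=None):
        self.value = value
        self.consumed = Flag() if consumed is None else consumed

    def set(self, value):
        if self.consumed.value:
            raise ValueError("state was already consumed")
        self.consumed.value = True  # visible to every object sharing this box
        return ToyState(value)

    def tree_flatten(self):
        return (self.value,), self.consumed

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)


def roundtrip(tree):
    # Stand-in for the flatten+unflatten step a transform applies to its inputs.
    leaves, treedef = tree_util.tree_flatten(tree)
    return tree_util.tree_unflatten(treedef, leaves)


def f(state):
    inner = roundtrip(state)  # f only ever touches a rebuilt copy
    return inner.set(inner.value + 1)


x = ToyState(jnp.array(0))
first = f(x)
try:
    f(x)  # recomputing f(x), as a checkpointing scheme might
    second_call_ok = True
except ValueError:
    second_call_ok = False
```

Here the first call succeeds, but the second call raises because `x` was consumed from inside `f`, even though `f` never mutated `x` directly.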

For that reason, I'm afraid my current conclusion is that the status quo is as good as it gets. :(

patrick-kidger avatar Aug 24 '25 12:08 patrick-kidger