Bug: state invalidation does not work across JIT boundaries
Hello @patrick-kidger !
Thanks for the amazing work with Equinox, I'm a huge fan and user 👍🏻
I would like to report a bug with state invalidation across JIT boundaries: if I jit-compile a function that calls equinox.nn.State.set() on a state object, the state object is not marked as invalid/old and could be reused multiple times.
Looking at the source code, I believe this is because JAX's tracing simply drops the mutation step, self._state = _sentinel, which prevents Equinox from flagging this state as already consumed and hence invalid.
Here is an MRE, using the example of the stateful API:
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array
import jax


class Counter(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    @jax.jit
    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        value = state.get(self.index)
        new_x = x + value
        new_state = state.set(self.index, value + 1)
        return new_x, new_state


counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)
_, new_state = counter(x, state)
print(state)  # State(0x77dd10d000d0=weak_i32[])
print(state._state)  # {0: Array(0, dtype=int32, weak_type=True)}
If I remove the @jax.jit decorator in the example above, I get the intended behavior:
print(state) # State(~old~)
print(state._state) # _Sentinel()
I am not sure whether this can be fixed, since side effects like this break the pure-function paradigm assumed by JAX, and more specifically by higher-order transforms like jit. Does this mean we should not jit-compile functions involving state objects? Or can we do so, but then lose Equinox's built-in safeguard against reusing a previous state? I tend to think the latter, but just want to double-check.
Looking forward, I think it might be beneficial to get rid of this safeguard, to make sure the behavior is consistent with/without jit. Happy to draft a PR if you agree!
Thanks a lot for the help on this.
FYI, here are the versions of JAX/Equinox used for the MRE above:
>>> jax.__version__
'0.5.3'
>>> eqx.__version__
'0.11.12'
So the behavior here is actually intentional, in that:
- the safeguard is important to catch a common class of error: forgetting to use the output from .set.
- flattening+unflattening offers a way to clone the state if really desired.
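To illustrate the second point, here is a minimal sketch using a toy pytree. ToyState, its consumed attribute, and step are hypothetical names made up for this example; this is not Equinox's actual State. It assumes only that jax.jit flattens its arguments and rebuilds them inside the traced function, so a Python-level mutation lands on the rebuilt copy rather than on the caller's object.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyState:
    """Hypothetical stand-in for a state object (not Equinox's State)."""

    def __init__(self, value):
        self.value = value
        self.consumed = False  # plain Python attribute, not part of the pytree

    def tree_flatten(self):
        # Only `value` is a leaf; `consumed` is not carried along.
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilding always yields a fresh object with consumed=False.
        return cls(children[0])

    def set(self, value):
        self.consumed = True  # mark this instance as used
        return ToyState(value)


def step(state):
    return state.set(state.value + 1)


eager = ToyState(jnp.array(0))
step(eager)
print(eager.consumed)  # True: the mutation hits the caller's object

jitted = ToyState(jnp.array(0))
jax.jit(step)(jitted)
print(jitted.consumed)  # False: the mutation hit a rebuilt copy

# Flattening + unflattening by hand gives the same kind of fresh copy.
leaves, treedef = tree_util.tree_flatten(eager)
clone = tree_util.tree_unflatten(treedef, leaves)
print(clone.consumed)  # False
```

The last three lines are the manual "clone" route: the rebuilt object carries the same leaves but none of the Python-side bookkeeping.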
That said, maybe we could do something a bit better here: put the flag in a box, and put that box in the static metadata. Then using .set would propagate even across flattening+unflattening boundaries (such as JIT). We could then add an explicit .clone() method to handle the case in which cloning is desired.
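A rough sketch of the box idea on a toy pytree, to make it concrete. Flag, BoxedState, and step are hypothetical names for illustration only, not Equinox code. It assumes the box can be passed as pytree aux data (which JAX treats as static metadata) and that the same box object is handed back on unflattening.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Flag:
    """Hypothetical mutable box; hashed and compared by identity."""

    def __init__(self):
        self.consumed = False


@tree_util.register_pytree_node_class
class BoxedState:
    """Toy state whose 'consumed' flag lives in static metadata."""

    def __init__(self, value, flag=None):
        self.value = value
        self.flag = Flag() if flag is None else flag

    def tree_flatten(self):
        # The box travels as aux data, so the rebuilt copy shares it.
        return (self.value,), self.flag

    @classmethod
    def tree_unflatten(cls, flag, children):
        return cls(children[0], flag)

    def set(self, value):
        self.flag.consumed = True  # flips the shared box
        return BoxedState(value)   # new state gets a fresh box


@jax.jit
def step(state):
    return state.set(state.value + 1)


old = BoxedState(jnp.array(0))
new = step(old)
print(old.flag.consumed)  # True: the flip propagated out of jit
print(new.flag.consumed)  # False
```

One caveat with this sketch: JAX runs the Python body only while tracing, so the flip above happens at trace time rather than on every call.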
WDYT?
Hi again!
Thanks for the quick reply. I think I understand your point, and I like the idea. Just to double-check: is this roughly what you had in mind?
(⚠️ just a sketch to convey the gist, not runnable as-is)
def _raise_state_error():
    raise ValueError(_state_error)


class Flag:
    def __init__(self, value):
        self.flag = value


class State(Module):
    _state: _Sentinel | dict[object | int, Any]
    _consumed: Flag = field(static=True)  # mutable box in static metadata

    def __init__(self, model):
        state = {}
        leaves = jtu.tree_leaves(model, is_leaf=_is_index)
        for leaf in leaves:
            if _is_index(leaf):
                if isinstance(getattr(leaf, "init", _Sentinel()), _Sentinel):
                    raise ValueError(
                        "Do not call `eqx.nn.State(model)` directly. "
                        "Use `eqx.nn.make_with_state(ModelClass)(...)` instead."
                    )
                state[leaf.marker] = jtu.tree_map(jnp.asarray, leaf.init)
        self._state = state
        self._consumed = Flag(False)  # each state starts with its own unflipped box

    def set(self, item: StateIndex[_Value], value: _Value) -> "State":
        if self._consumed.flag:
            # use callback so the error is raised at *runtime* under JIT
            jax.debug.callback(_raise_state_error)
        ...  # build new state as before
        new_self = object.__new__(State)
        new_self._state = state
        new_self._consumed = Flag(False)  # the new state gets a fresh box
        self._consumed.flag = True  # flip flag on this instance
        return new_self

    def get(self, item: StateIndex[_Value]) -> _Value:
        if self._consumed.flag:
            # likewise, enforce at runtime under JIT
            jax.debug.callback(_raise_state_error)
        ...
        return self._state[item.marker]
Do we agree this would work because the mutation happens on a static field, so JAX's tracer will not drop it? How safe/fragile is that given JAX's FP paradigm? Also, I think I need jax.debug.callback for full JIT compatibility, no?
If I've understood correctly, I'd be happy to draft a proper diff along these lines. Does this capture what you had in mind?
I've been noodling on this idea, and I've realised a potential problem with it.
Namely, that there is a lot of JAX code out there that basically relies on the fact that calling f(x) twice will give you the same answer each time. For example, recomputing from checkpoints.
We'd be breaking that invariant here: by introducing this Flag, we'd be getting some pretty spooky action-at-a-distance, where the internals of f (using state.set) result in mutating the x (the state), and thus prevent you from calling f(x) again!
This would almost certainly break a lot of JAX code, mostly in difficult-to-reason-about edge cases.
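A toy illustration of that failure mode in plain Python. Flag, ToyState, rebuild, and f are hypothetical names; rebuild is only a stand-in for flattening+unflattening with the flag kept in static metadata, and none of this is Equinox code.

```python
class Flag:
    """Hypothetical shared mutable box."""

    def __init__(self):
        self.consumed = False


class ToyState:
    def __init__(self, value, flag=None):
        self.value = value
        self.flag = Flag() if flag is None else flag

    def rebuild(self):
        # Stand-in for flatten + unflatten: same value, *same* box.
        return ToyState(self.value, self.flag)

    def set(self, value):
        if self.flag.consumed:
            raise ValueError("state already consumed")
        self.flag.consumed = True
        return ToyState(value)


def f(state):
    # Works on a rebuilt copy, as a transformed function would.
    return state.rebuild().set(state.value + 1)


x = ToyState(0)
first = f(x)

try:
    f(x)  # e.g. a recomputation of the same call
    second_call_ok = True
except ValueError:
    second_call_ok = False

print(first.value, second_call_ok)  # 1 False
```

The caller never touches x directly, yet the second f(x) fails because the first call flipped the box that x shares with its rebuilt copy.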
For that reason, I'm afraid my current conclusion is that the status quo is as good as it gets. :(