
`filter_jit` cache misses when using non-array variables in an `equinox` class

Open · maxwest-uw opened this issue 5 months ago · 1 comment

As part of the scarlet2 project we use Equinox classes to represent various data structures. When we ran jobs with more data than in our typical use case, our fitting methods (where the more computationally expensive calculations are wrapped in `equinox.filter_jit`) ran very slowly. We eventually determined that the slowdown came from `filter_jit` missing its cache repeatedly, even for very simple operations.

  • Specifically, we determined that the slowdown was due to how Equinox handles non-array variables that change from one call to the next.
  • When we changed the variables being updated from single integers into arrays, compilation behaved as expected and the code ran quickly.
  • When @b-remy moved the computation out of Equinox into pure JAX, it was also fast and did not miss the cache on every call.
  • Running the problematic code a second time showed that a separate compilation had happened for each iteration of the loop. As @pmelchior put it: "jit is recompiling the `_step` function when the argument goes from `MyModule(x=0)` to `MyModule(x=1)` and so on."

It's entirely possible that this slowdown should be handled on our end, since we're using Equinox in a non-standard way. But we suspect our setup is not so unusual that no other teams are using something similar and hitting the same slowdowns, so we figured we'd report this :)

Reproduction

This is the toy "minimal unviable product" we used for testing. It is based on the structure of the much more complicated scene class and computation where we saw the slowdown.

```python
import equinox as eqx
import jax
import numpy as np

# inner computation: adds one to every leaf of `scene` and returns the result
# (`x` is unused; it only mimics the signature of our real step function)
@eqx.filter_jit
def _step(scene, x):
    updates = jax.tree_util.tree_map(lambda s: s + 1, scene)
    return updates

class MyModule(eqx.Module):
    # same way variables are defined in `Scene`
    x: np.array

    def __init__(self):
        self.x = 0  # a plain Python int, i.e. a non-array field

    def fit(self, steps):
        scene_ = self
        for _ in range(steps):
            # sort of supposed to mimic the `_make_step` method
            scene_ = _step(scene_, 1)
        return scene_

mymod = MyModule()
# this operation takes about 20 seconds to add 1 to a variable 3000 times
s = mymod.fit(3000)
print(s.x)  # should print 3000
```

If we recreate the structure above but store `MyModule.x` in an array of size 1:

```python
import equinox as eqx
import jax
import numpy as np

# same step as before: adds one to every leaf of `scene`
@eqx.filter_jit
def _step(scene, x):
    updates = jax.tree_util.tree_map(lambda s: s + 1, scene)
    return updates

class MyModule(eqx.Module):
    # same way variables are defined in `Scene`
    x: np.array

    def __init__(self):
        self.x = np.zeros(1)  # now an array of size 1, so it is traced

    def fit(self, steps):
        scene_ = self
        for _ in range(steps):
            # sort of supposed to mimic the `_make_step` method
            scene_ = _step(scene_, 1)
        return scene_

mymod = MyModule()
# this now takes 100 ms
s = mymod.fit(3000)
print(s.x)  # prints a size-1 array containing 3000.0, not the int 3000
```

maxwest-uw · Aug 01 '25 18:08

Hey there! It's great to hear about how you're using Equinox.

As for the caching you're seeing, this is expected behaviour: `equinox.filter_jit` caches based on the shape and dtype of arrays, and on the value of all non-arrays. In this case you have a non-array that is changing, and so this is triggering recompilation.

You can use `eqx.debug.assert_max_traces` to catch this and explicitly error out.

This is intentionally different from `jax.jit` (it's what the `filter_` prefix refers to), in that `jax.jit` will convert `bool | int | float | complex` into arrays when caching/tracing. We avoid doing this as it is common to use these values as e.g. boolean flags, and tracing them would mean that statements like `if flag:` fail. If tracing based on those values is desired, then a user can simply wrap them in arrays. In this way we aim to be a bit more flexible.

patrick-kidger · Aug 01 '25 23:08