
Jitting Callback Function

sgstepaniants opened this issue 9 months ago • 7 comments

I have a question concerning speeding up my callback function, which plots several solutions to a neural ODE. Because evaluating the neural ODE model takes a bit of time (and I make several plots), the callback function ends up taking more time to plot 3 neural ODE predictions than it takes to run through one batch of 30 trajectories in a training step. I think this is because my callback function and model are not jitted. Following equinox and diffrax best practices, I am currently only jitting the functions `make_step` and `grad_loss`. So currently my callback function, which I call after every batch training step, is unjitted and looks like this:

```python
def callback(steps, key, selected_inds=jnp.array([]), num_random=0):
    ymin = min(y_full)
    ymax = max(y_full)

    random_inds = jnp.arange(y_train.shape[0])
    random_inds = jnp.setdiff1d(random_inds, selected_inds)
    random_inds = jr.choice(key, random_inds, shape=(num_random,), replace=False)
    inds = jnp.concatenate([selected_inds, random_inds])

    fig, axes = plt.subplots(1, len(inds), figsize=(20, 4))
    for idx, ax in enumerate(axes):
        i = int(inds[idx])
        tpred, ypred = model(steps, y_train[i, :, None])
        # some plotting code ...

    plt.tight_layout()
    plt.show()
```

Is there a convention for how to jit callbacks that do plotting? One reason I am not clear on how to proceed is that the callback evaluates the neural ODE model n times to make n plots, and the JAX guidelines suggest that running model evaluations in a Python for loop can be slow. Avoiding that would seem to require me to @eqx.filter_jit my model's call function, which I gather is unadvised.

sgstepaniants avatar Apr 09 '25 18:04 sgstepaniants

You should separate your numerical JAX bit and the general 'software' bit of plotting. Do all JAX operations inside a jit'd region, then pass their output to the rest of your program.
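As a rough sketch of what that separation might look like (reusing `model`, `steps`, `y_train`, and `inds` from the question's callback; the `compute_predictions` name and the vmap-over-trajectories approach are illustrative choices, not prescribed by this answer):

```python
import equinox as eqx
import jax

@eqx.filter_jit
def compute_predictions(model, steps, ys):
    # All numerical work happens in one jitted region; vmap replaces the
    # Python for loop over the selected trajectories.
    return jax.vmap(lambda y0: model(steps, y0))(ys)

# Outside jit: ordinary Python and matplotlib, consuming plain arrays.
tpreds, ypreds = compute_predictions(model, steps, y_train[inds, :, None])
```

One caveat: this recompiles whenever the number of selected trajectories changes, since the batch size is part of the traced shapes.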

patrick-kidger avatar Apr 09 '25 21:04 patrick-kidger

I see, so the only numerical operation I have in the callback is when I call my equinox model's `__call__` function in a for loop. Should I jit the `__call__` function that is defined in the equinox model class? I didn't see any examples doing this, so I avoided taking this route.

sgstepaniants avatar Apr 09 '25 21:04 sgstepaniants

Just wanted to bump this question again. If we want a callback that efficiently plots the results of our model every epoch or so, should we be jitting the `__call__` methods in our equinox model class?

sgstepaniants avatar Apr 16 '25 22:04 sgstepaniants

You could decorate `__call__` with jit, but more commonly you instead want to jit your loss function and the `make_step` function that applies the optimiser updates.
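(For completeness, a minimal sketch of the first option, using a hypothetical module rather than code from this thread. An `eqx.Module` is a pytree, so `self` passes through `eqx.filter_jit` like any other argument:)

```python
import equinox as eqx
import jax

class Model(eqx.Module):
    weight: jax.Array

    @eqx.filter_jit  # `self` is an eqx.Module, hence a pytree, and is filtered like any other argument
    def __call__(self, x):
        return self.weight * x
```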

For plotting, you could create an outer loop that is just Python - looping over epochs - which then has an inner loop over batches. At the end of each pass of this outer loop, you can extract some value and create a plot. This setup allows you to compile the expensive stuff and still generate plots, without requiring the use of a callback.

johannahaffner avatar Apr 17 '25 07:04 johannahaffner

That is roughly what I do: I only use my callback to make plots every epoch, with many batch steps in between. It is the "extract some value" step that requires a simulation of my model for plotting. So basically you're saying that if I am calling a callback for plotting relatively infrequently (e.g. every 100 batch steps), then it's not worth the hassle of trying to jit this function, since that would require me to somehow jit my model outside of make_step, which is a headache.

sgstepaniants avatar Apr 17 '25 19:04 sgstepaniants

Yes, that sounds like more trouble than it's worth! You could probably do something like:

```python
@eqx.filter_jit
def loss(model, data):
    ...

@eqx.filter_jit
def make_step(...):
    ...

for epoch in epochs:
    # Train across batches, get updated model (this is where make_step gets called)
    current_loss = loss(updated_model, data)
    # Either plot something or just record the loss to plot something later
```
And so on. The plotting function is in any case not something for which you can expect any kind of performance gain from placing it in a jitted region via a callback workaround. I never do this - and I'm not even sure it could be done reliably: https://docs.jax.dev/en/latest/external-callbacks.html#flavors-of-callback

johannahaffner avatar Apr 17 '25 20:04 johannahaffner

I think jitting the loss function is not necessary, as the loss function is called inside make_step and make_step is already jitted; doing so would jit it twice. Even if you do want to jit the loss function, it should be jax.jit(jax.grad(...)) rather than jax.grad(jax.jit(...)).
I think jitting the outermost function is what matters for optimal performance.
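For illustration, a minimal sketch of jitting only the outermost step (the optimiser setup and the `optim` name here are assumptions for the example, not taken from the thread):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

optim = optax.adam(1e-3)  # assumed optimiser, just for the example

def loss(model, x, y):
    # Mean-squared error over a batch; vmap maps the model over the batch axis.
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

@eqx.filter_jit  # jit the outermost function; loss and its gradient are traced inside it
def make_step(model, opt_state, x, y):
    loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value
```

Note that the inner `eqx.filter_value_and_grad(loss)` is not itself jitted; it is simply traced as part of `make_step`, which is the point being made above.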

ak24watch avatar Aug 04 '25 08:08 ak24watch