
Old versions much faster?

Open cgiovanetti opened this issue 1 month ago • 4 comments

I have a large project that's intended to run on CPU. During development it ran in well under a second per evaluation, but now it runs almost twice as slowly.

I have more or less pinned the issue down to my JAX version: 0.4.26 and 0.4.28 run fast, and every later version is slower. However, when creating an MWE I can only reproduce the slowdown if I use diffrax, so I wonder whether something in later versions interacts with diffrax to cause it. I also opened a related issue on the JAX GitHub.

Here is the MWE:

import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt

import time

def fun(rtol=1e-8, atol=1e-10, solver=Tsit5()):
    T_EM_init = 8.6
    rho_extra_init = 830.

    Y0 = (0., T_EM_init)

    sol = diffeqsolve(
        ODETerm(dY), solver, args=rho_extra_init,  # a bare float, not a tuple
        t0 = 0., t1=100., dt0=None, y0=Y0, 
        saveat=SaveAt(steps=True), 
        stepsize_controller = PIDController(
            rtol=rtol, atol=atol
        ), 
        max_steps=512
    )

    a_vec = jnp.exp(sol.ys[0])

    return (
        a_vec
    )

def dY(t, Y, args): 
    lna, T_g = Y
    rho_extra_init = args

    rho_EM = T_g**4
    rho_extra = rho_extra_init * 1. / jnp.exp(lna)**4 

    H = (rho_EM + rho_extra)**0.5
    drho_EM_dt = -3 * H * rho_EM
    dT_g_dt = drho_EM_dt / (4*T_g**3)

    return H, dT_g_dt

for i in range(5):
    start = time.time()
    a_vec = jax.block_until_ready(jax.jit(fun)())
    print(time.time() - start)

Running with jax/jaxlib==0.4.28 and diffrax==0.6.0, each compiled iteration runs in ~0.00015 s on an M1 Mac. Running with either jax/jaxlib==0.6.2 or 0.8.1 and diffrax==0.7.0, each compiled iteration runs in 0.0004 s on the same hardware. I suspect this is not directly an issue with newer versions of diffrax, because I also see a slowdown in my production code with jax/jaxlib==0.4.29 and diffrax==0.6.0, but maybe something second-order in the way diffrax uses JAX is responsible?

In absolute terms it's not much, but I suspect this also accounts for the factor-of-two slowdown in my actual production code. Incidentally, if my differential equation is a function of only one variable (i.e., I track only H and not dT_g_dt), the compiled code runs faster on newer versions. Using eqx.filter_jit doesn't seem to help much in either case.
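With per-call times around 1e-4 s, a 5-iteration loop timed with `time.time()` is fairly noisy, and its first iteration includes compilation. A minimal sketch of a steadier comparison, assuming only public JAX APIs (`jax.jit(...).lower(...).compile()` and `jax.block_until_ready`); the `benchmark` helper is hypothetical, not part of JAX or diffrax:

```python
import statistics
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeats=200):
    """Hypothetical helper: time compilation and steady-state execution separately.

    Returns (compile_seconds, median_run_seconds).
    """
    jitted = jax.jit(fn)

    # Ahead-of-time lowering and compilation, timed on its own.
    start = time.perf_counter()
    compiled = jitted.lower(*args).compile()
    compile_s = time.perf_counter() - start

    # One warm-up call so first-call overheads don't skew the median.
    jax.block_until_ready(compiled(*args))

    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(compiled(*args))  # wait for async dispatch to finish
        times.append(time.perf_counter() - start)
    return compile_s, statistics.median(times)


if __name__ == "__main__":
    print("jax", jax.__version__)
    # Toy stand-in; substitute the MWE's `fun` (called with no arguments) here.
    compile_s, run_s = benchmark(lambda x: jnp.sin(x) * 2.0, jnp.ones(8))
    print(f"compile: {compile_s:.4f} s, median run: {run_s:.6f} s")
```

With the MWE, `benchmark(fun)` should report compile time and median run time separately, which makes a version-to-version comparison less sensitive to one-off noise.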

cgiovanetti • Dec 02 '25 20:12

This is disappointingly expected, I'm afraid. JAX has made a number of performance pessimisations over the past year or so.

I think @lockwo had collected some of these examples / has bumped into many of these as well, and might be able to give more specific details.

It might be worth us making sure that modern Diffrax remains compatible with older JAX releases, for the sake of performance...

patrick-kidger • Dec 02 '25 20:12

Thanks for the fast reply! Sticking to old JAX releases will work in some cases, but in general it isn't an ideal workaround for me: I'm working on a mixed CPU/GPU project and really struggle to get JAX with CUDA running for some of these older versions. @lockwo, if you have other examples to share, maybe the JAX team could help upstream?

cgiovanetti • Dec 02 '25 23:12

Most of my examples were CPU-focused (although there was some GPU work in https://github.com/jax-ml/jax/issues/20968). What usually solved my issues was simply disabling XLA's new CPU thunk runtime. The runtime has improved, but I still don't think it's as good (https://github.com/jax-ml/jax/issues/30554). A common problem pattern I encountered was double (nested) while loops of varying lengths, which you have in diffrax when adaptive step sizes are used, so that might be something to check, or to build a pure-JAX replication around.
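Disabling the thunk runtime is done through the `XLA_FLAGS` environment variable. A minimal sketch, assuming your jaxlib still recognizes `--xla_cpu_use_thunk_runtime` (newer releases may not honor it, and an unrecognized flag may make XLA error at startup); the helper name is hypothetical:

```python
import os

# Assumption: this jaxlib build still recognizes the flag below.
THUNK_FLAG = "--xla_cpu_use_thunk_runtime=false"


def disable_cpu_thunk_runtime() -> str:
    """Hypothetical helper: append THUNK_FLAG to XLA_FLAGS once and return the result.

    Call this before `import jax`, so the flag is in place before XLA
    initializes its CPU backend.
    """
    flags = os.environ.get("XLA_FLAGS", "").split()
    if THUNK_FLAG not in flags:  # don't add the flag twice
        flags.append(THUNK_FLAG)
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


if __name__ == "__main__":
    print(disable_cpu_thunk_runtime())
    # import jax  # import JAX only after XLA_FLAGS has been set
```

Equivalently, the flag can be exported in the shell before launching Python; setting it from inside a script only works if no JAX computation has run yet in that process.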

lockwo • Dec 03 '25 17:12

Interesting: this cuts my runtime way down for jax==0.6.2, and has virtually no effect for jax==0.8.1. Thanks for the insight! I'll see if I can get an MWE with a double while loop.
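A pure-JAX replication along these lines might start from a nested `lax.while_loop` whose inner trip count varies per outer iteration. This is a hypothetical toy (Euler/Heun step-halving on dy/dt = -y**3), written only to exercise the loop pattern described above; it is not claimed to mirror how diffrax structures its loops, and the names `rhs`, `try_step` and `solve` are illustrative:

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


def rhs(y):
    # Toy ODE dy/dt = -y**3 (exact solution y = 1/sqrt(1 + 2t) for y0 = 1).
    return -(y**3)


def try_step(y, dt):
    # Take an Euler and a Heun step; their difference is a crude local error estimate.
    k1 = rhs(y)
    y_euler = y + dt * k1
    k2 = rhs(y_euler)
    y_heun = y + 0.5 * dt * (k1 + k2)
    return y_heun, jnp.abs(y_heun - y_euler)


def solve(y0, t1=10.0, dt0=1.0, tol=1e-6, max_steps=16384):
    def outer_cond(state):
        t, _, _, n = state
        return (t < t1) & (n < max_steps)

    def outer_body(state):
        t, y, dt, n = state
        dt = jnp.minimum(dt, t1 - t)  # don't step past t1

        # Inner loop: halve dt until the error estimate is within tol.
        # Its trip count varies from one outer iteration to the next.
        def inner_cond(inner):
            _, _, err = inner
            return err > tol

        def inner_body(inner):
            dt, _, _ = inner
            dt = 0.5 * dt
            y_new, err = try_step(y, dt)
            return dt, y_new, err

        y_new, err = try_step(y, dt)
        dt, y_new, _ = lax.while_loop(inner_cond, inner_body, (dt, y_new, err))
        return t + dt, y_new, 2.0 * dt, n + 1  # accept, then try a larger step

    init = (
        jnp.asarray(0.0, dtype=jnp.float32),
        jnp.asarray(y0, dtype=jnp.float32),
        jnp.asarray(dt0, dtype=jnp.float32),
        jnp.asarray(0, dtype=jnp.int32),
    )
    t, y, _, n = lax.while_loop(outer_cond, outer_body, init)
    return t, y, n


if __name__ == "__main__":
    solve_jit = jax.jit(solve)
    jax.block_until_ready(solve_jit(1.0))  # compile + first run
    start = time.perf_counter()
    for _ in range(100):
        jax.block_until_ready(solve_jit(1.0))
    print("mean run time:", (time.perf_counter() - start) / 100, "s")
```

Timing this under each JAX version, with and without the thunk-runtime flag, would indicate whether the nested-loop pattern alone reproduces the slowdown, independently of diffrax.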

cgiovanetti • Dec 03 '25 17:12