os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' no longer working with Jax 0.6.0
Hello
I had experienced some slowdown in runtime using diffrax with JAX 0.5.3 and was recommended to use this flag:
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
This works great with JAX 0.5.3, giving me the speed of older versions. However, with JAX 0.6.0 the script crashes when calling the ODE solver. Removing the flag allows my code to run, but noticeably slower than older JAX versions with the flag. Have other people experienced this, and do you have any suggestions?
Do you have an MWE?
That aside, I'd also recommend raising this issue on the JAX issue tracker (with a reproducible example) if you're seeing severe performance drops.
My main use case is a much larger problem, and I have not yet created an example concise enough to post here that fully shows the runtime difference gained from using os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' on JAX 0.5.3 as suggested by 606. However, with the example below I have been able to pinpoint what is causing my segmentation fault. Using the flag on this example causes a segmentation fault because of the progress_meter: removing the progress meter gives no error with or without the XLA flag, but the program fails if the meter and the flag are used together. Here, with the Tsit5() solver, I observe a runtime of 0.014693737030029297 seconds with the flag and 0.09974408149719238 seconds without. The difference here is minor, but it becomes significant when applied to larger problems which take multiple hours to run.
If anyone has insight into why the progress meter causes this issue when using JAX 0.6.0 and the XLA flag, I would be interested to hear it. This was not an issue with JAX 0.5.3.
import os

os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'

import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


class Robertson(eqx.Module):
    k1: float
    k2: float
    k3: float

    def __call__(self, t, y, args):
        f0 = -self.k1 * y[0] + self.k3 * y[1] * y[2]
        f1 = self.k1 * y[0] - self.k2 * y[1] ** 2 - self.k3 * y[1] * y[2]
        f2 = self.k2 * y[1] ** 2
        return jnp.stack([f0, f1, f2])


@jax.jit
def main(k1, k2, k3):
    robertson = Robertson(k1, k2, k3)
    terms = diffrax.ODETerm(robertson)
    t0 = 0.0
    t1 = 100.0
    y0 = jnp.array([1.0, 0.0, 0.0])
    dt0 = 0.0002
    solver = diffrax.Tsit5()
    saveat = diffrax.SaveAt(ts=jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2]))
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0=t0,
        t1=t1,
        dt0=dt0,
        y0=y0,
        saveat=saveat,
        max_steps=10000000000,
        stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-8, pcoeff=0.4, icoeff=0.3),
        progress_meter=diffrax.TextProgressMeter(),
    )
    return sol


main(0.04, 3e7, 1e4)

start = time.time()
sol = main(0.04, 3e7, 1e4)
end = time.time()

print("Results:")
for ti, yi in zip(sol.ts, sol.ys):
    print(f"t={ti.item()}, y={yi.tolist()}")
print(f"Took {sol.stats['num_steps']} steps in {end - start} seconds.")
Oh, this is super weird. I'm able to reproduce this behaviour. Here's a smaller MWE:
import os

os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import diffrax
import equinox as eqx
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


@jax.jit
def main(y0):
    terms = diffrax.ODETerm(lambda t, y, args: -y)
    solver = diffrax.Euler()
    progress_meter = diffrax.TextProgressMeter()
    sol = diffrax.diffeqsolve(
        terms, solver, t0=0, t1=1, dt0=0.2, y0=y0, progress_meter=progress_meter
    )
    return sol.ys


main(jnp.array(1.0))
If I had to guess, it's something to do with the use of callbacks inside the progress meter. Perhaps a JAX-only MWE could be constructed demonstrating an issue with those.
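I'm not sure exactly which callback mechanism ends up being exercised here, but as a rough guess at what such a JAX-only MWE might look like (purely an assumption about the structure): a host callback fired from inside a lax.while_loop, with the same XLA flag set.

import os

os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import jax
import jax.numpy as jnp


def report(t):
    # Host-side side effect, analogous to printing a progress update.
    print(f"progress: {t}")


@jax.jit
def run(t1):
    def cond(carry):
        t, _ = carry
        return t < t1

    def body(carry):
        t, y = carry
        jax.debug.callback(report, t)  # callback from inside the loop body
        return t + 0.2, y * 0.8

    return jax.lax.while_loop(cond, body, (0.0, 1.0))


run(1.0)

If something like this crashes under the same flag, it would make a good standalone report for the JAX issue tracker.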