CPU slowdown with new runtime (v0.4.32 and newer)
Description
Thanks a lot for your efforts in building JAX, I love working with it!
On my MacBook Pro CPU (M3), my differentiable simulator runs 5x to 10x slower on new versions of JAX (v0.4.32 or newer) as compared to older versions (v0.4.31 or older).
Setting the following XLA flag fixes the issue for me in newer versions of JAX (i.e., speed is as before):
import os
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
Unfortunately, a 5x slower runtime would probably kill any relevant use case of my simulator. As such, I have two questions:
- Will the old CPU runtime continue to be maintained via the XLA flag?
- Do you have any obvious candidates for operations that could cause this behavior? Any ideas on where I should start looking?
Related to #26021 and #25808
Thanks a lot!
To reproduce
As my codebase is fairly large, I cannot easily provide a self-contained example without relying on my toolbox. To reproduce:
pip install jaxley==0.1.2
and
import os
# Note: XLA flags must be set before JAX initializes its backend.
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'

import time

from jax import config, jit
config.update("jax_platform_name", "cpu")
config.update("jax_enable_x64", True)

import jaxley as jx
from jaxley.channels import HH

comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1] + [b // 2 for b in range(0, 2**8 - 2)])
cell.insert(HH())
cell.branch(0).comp(0).record()
cell.make_trainable("radius")
params = cell.get_parameters()

@jit
def simulate(params):
    return jx.integrate(cell, params=params, t_max=1_000.0)
start_time = time.time()
simulate(params).block_until_ready()  # first call: compile + run
print("Compile time: ", time.time() - start_time)

start_time = time.time()
simulate(params).block_until_ready()  # second call: run only
print("Run time: ", time.time() - start_time)
On my MacBook Pro (and JAX v0.5.0, jaxlib v0.5.0, Python 3.12), I get:
Compile time: 7.672
Run time: 2.067
Removing the line os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false', I get:
Compile time: 13.620
Run time: 11.293
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.5.0
jaxlib: 0.5.0
numpy: 2.2.2
python: 3.12.8 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 10:37:40) [Clang 14.0.6 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='Michaels-MBP', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:11:08 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T8122', machine='arm64')
EDIT: When using float32 (i.e., with x64 disabled), the gap between the old and new versions of JAX is even more prominent (almost 10x):
config.update("jax_enable_x64", False)
Ping @ezhulenev and @penpornk for CPU thunks runtime performance
As noted, it's hard to say too much without a more minimal reproducer, but to answer the specific questions:
Will the old CPU runtime continue to be maintained via the XLA flag?
I think there continue to be enough performance regressions that we would like to keep this working for the time being. I think the right approach is to suggest the use of that flag for now with the assumption that the reported performance regressions will be reasonably addressed before it is removed. @ezhulenev and @penpornk might be able to comment more about timelines on the XLA side!
Do you have any obvious candidates for operations that could cause this behavior? Any ideas on where I should start looking?
The most common culprit in my experience is loops (like scan or while_loop), although sometimes these can be implicit. If you have explicit scans in your library, you can experiment with trading off compile time against runtime performance using the unroll parameter (increasing it lengthens compile time but typically improves runtime performance).
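For example, with jax.lax.scan the trade-off looks like this (a minimal sketch; the step function is illustrative):

import jax
import jax.numpy as jnp

def step(carry, _):
    # Toy update; stands in for one simulation step.
    carry = carry + 0.01 * jnp.tanh(carry)
    return carry, carry

@jax.jit
def run(x0):
    # unroll=8 trades longer compile time for (often) faster runtime;
    # the default is unroll=1.
    _, ys = jax.lax.scan(step, x0, None, length=40_000, unroll=8)
    return ys

ys = run(jnp.zeros(16))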
We plan to fix the performance regression in the next couple of weeks (@cota), and only after that will we start removing the old runtime.
That's great to know, thanks a lot for the quick response!
My simulator integrates an ODE and indeed uses a scan (across 40k timepoints), so this might well be the culprit. Unrolling is not an option for me because compile time becomes excessive.
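For reference, the integration loop is structurally similar to this (a minimal sketch with toy explicit-Euler dynamics, not the actual jaxley code):

import jax
import jax.numpy as jnp

def euler_step(state, _):
    # Toy dynamics; jaxley's actual step is far richer.
    dt = 0.025
    new_state = state + dt * (-state + 1.0)
    return new_state, new_state[0]  # carry the full state, record one entry

@jax.jit
def integrate(state0):
    # ~40k timepoints: unrolling this scan makes compile time explode.
    _, trace = jax.lax.scan(euler_step, state0, None, length=40_000)
    return trace

trace = integrate(jnp.zeros(1024))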
For now, I will simply use the old CPU runtime via the XLA flag. Thank you for the suggestions!
@michaeldeistler — There have been some significant improvements in the CPU runtime with small loops. I'd be interested to know if things are better with the latest JAX (v0.6.0) on your benchmarks.
Cool!
On JAX==0.6.0, I now have:
Compile time: 3.9083
Run time: 2.384
In other words: Compile time is almost 2x faster than ever before! Run time is still a bit slower (~10-20%) than old JAX versions (<= v0.4.31), but much improved compared to, e.g., v0.4.35.
It would of course be fantastic to overcome the remaining 10-20% slowdown in runtime, but for my applications this slowdown is no longer a blocker. I will remove the version pin on my toolbox :)
Thanks for the heads-up, this is great! Michael
Can you run it with XLA_FLAGS=--xla_dump_to=/path/to/dump/hlos? We'll take a look at why we still have the remaining 10% regression (+ @WillFroom).
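That is, something like the following before running the repro (the dump path is illustrative):

import os
# Must be set before the first JAX computation runs; XLA writes the
# HLO modules it compiles into this directory.
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/hlo_dump'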
Thanks for the repro. I collected performance profiles, and the problem is that we call libm's exp for f64 data on a hot path. We'll be fixing this soon.
And the 10% regression you see comes from the overhead of concurrent execution in the small loop. I'll look into that today; it might be easy to fix.
I have a "fix": it solves the 10% regression but adds a 10% regression to many important internal benchmarks :) I'll need more time to find a good cost model for the runtime heuristics.
Any update on this issue?
Not really; I could not find a good heuristic that makes things uniformly better. One option is to add a flag to XLA that gives the user some control over the scheduling strategy. Would this work for you?
Thanks for your answer. Not ideal: the current version is still a regression for my use case w.r.t. v0.4.31. My use case, unfortunately, combines several of the affected pieces: large scans with thousands of loop steps, the need for float64 precision, everything jitted, and many such simulations running in parallel on all CPU threads.
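To make the workload shape concrete, here is a minimal sketch (toy dynamics and illustrative sizes; XLA's CPU backend spreads the batched work over its thread pool):

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

def step(c, _):
    return c + 0.025 * (-c + 1.0), None

@jax.jit
def simulate_many(x0s):
    # One long float64 scan per simulation, batched with vmap.
    def one(x0):
        c, _ = jax.lax.scan(step, x0, None, length=40_000)
        return c
    return jax.vmap(one)(x0s)

out = simulate_many(jnp.zeros((64, 128)))  # 64 simulations at once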
@seantalts should have a fix for f64 precision soon
Yep, working on it. Current benchmarks show pretty dramatic improvements!
name old time/op new time/op delta
BM_ExpF64/128/process_time 2.77µs ± 2% 1.03µs ± 5% -62.91% (p=0.000 n=40+39)
BM_ExpF64/256/process_time 4.79µs ± 2% 1.32µs ± 6% -72.45% (p=0.000 n=40+40)
BM_ExpF64/512/process_time 8.85µs ± 2% 1.89µs ± 5% -78.60% (p=0.000 n=40+40)
BM_ExpF64/1024/process_time 16.9µs ± 2% 3.1µs ± 7% -81.88% (p=0.000 n=40+40)
BM_ExpF64/4096/process_time 65.8µs ± 3% 10.1µs ± 8% -84.63% (p=0.000 n=39+40)
BM_ExpF64_Aot/128/process_time 2.87µs ± 2% 2.05µs ± 5% -28.57% (p=0.000 n=40+39)
BM_ExpF64_Aot/256/process_time 4.99µs ± 2% 3.34µs ± 5% -33.04% (p=0.000 n=40+40)
BM_ExpF64_Aot/512/process_time 9.24µs ± 2% 5.97µs ± 6% -35.39% (p=0.000 n=40+40)
BM_ExpF64_Aot/1024/process_time 17.7µs ± 2% 11.1µs ± 5% -37.54% (p=0.000 n=40+40)
BM_ExpF64_Aot/4096/process_time 69.0µs ± 2% 42.5µs ± 7% -38.37% (p=0.000 n=40+40)
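For anyone who wants to sanity-check this on their own machine, a rough JAX-level micro-benchmark along the same lines (illustrative; not the C++ benchmark above):

import time
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 4096)  # float64, since x64 is enabled
f = jax.jit(jnp.exp)
f(x).block_until_ready()  # compile once

start = time.perf_counter()
for _ in range(10_000):
    f(x).block_until_ready()
print("us per call:", (time.perf_counter() - start) / 10_000 * 1e6)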
The changes have landed in XLA (final PR). Would it be easy for you to re-run your benchmark/test with this version and see how it compares now?
And the easiest way to test that is probably to try with a jax/jaxlib nightly in a day or so.
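For reference, installing a nightly looks roughly like this (the index URL is the one from the JAX installation docs; double-check there in case it has changed):

pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/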
Hi everyone,
I just tried with the newest JAX nightly (JAX v0.6.3.20250618, jaxlib v0.6.3.20250619) and I get:
Compile time: 3.5977
Run time: 1.5199
This is a ~40% runtime speedup compared to v0.6.0, and a runtime speedup of ~25% compared to v0.4.31 and older. Thank you so much for your efforts, this is really amazing! Feel free to close this issue.
For future reference, here is an overview of the compile time and runtime of all versions that I checked:
# JAX v0.4.31
Compile time: 7.672
Run time: 2.067
# JAX v0.5.0
Compile time: 13.620
Run time: 11.293
# JAX v0.6.0
Compile time: 3.9083
Run time: 2.384
# JAX v0.6.3.20250618, jaxlib v0.6.3.20250619 (nightly)
Compile time: 3.5977
Run time: 1.5199
Cheers, Michael
I confirm that v0.6.3 is a solid improvement, well done! Can you please tell us what was changed? Is it the scan? The float64? Does it also fix this issue? Thanks
Thank you both for testing!
Can you please tell us what was changed? Is it the scan? the float64?
The main change I'm aware of was to make the exponential function on float64 much faster. @ezhulenev would know if there were also other changes that would have helped here.
Is it also fixing https://github.com/jax-ml/jax/issues/26021?
I don't see any exp in the HLO for this issue, so probably not.
Can we close this bug? Is everything filed here addressed?
For my side it can be closed.
Same here: for me it can be closed. Thanks
Unfortunately, I noticed that computing the gradient of my model still shows a large performance regression compared to v0.4.31 on CPU. I tested multiple versions (0.4.35, 0.6.0, 0.7.1), and all of them are slower than v0.4.31. However, the most recent JAX version, 0.7.1, is particularly slow at computing the gradient: about 8x slower than v0.4.31. Below are the HLOs:
Any ideas? Thank you very much for your help! Michael
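For reference, the gradient call being benchmarked is along these lines (a sketch reusing cell and params from the repro above; the scalar loss is illustrative):

import jax
import jax.numpy as jnp
import jaxley as jx

def loss(params):
    # Illustrative scalar loss over the simulated recording.
    v = jx.integrate(cell, params=params, t_max=1_000.0)
    return jnp.sum(v**2)

grad_fn = jax.jit(jax.grad(loss))
g = jax.block_until_ready(grad_fn(params))  # this is the slow call on 0.7.1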
@WillFroom @ezhulenev PTAL at @michaeldeistler's repro?
Hi @michaeldeistler, the particular slowdown in 0.7.1 looks like it has the same root cause as #31284, for which I have an in-flight fix here: https://github.com/openxla/xla/pull/30679. I will look into the remaining slowdown vs. v0.4.31.
Perfect, thanks! Let me know if HLOs of intermediate JAX versions (e.g., 0.6.0, 0.4.35,...) would be useful.