
CPU slowdown with new runtime (v0.4.32 and newer)

Open michaeldeistler opened this issue 11 months ago • 17 comments

Description

Thanks a lot for your efforts in building JAX, I love working with it!

On my MacBook Pro CPU (M3), my differentiable simulator runs 5x to 10x slower on new versions of JAX (v0.4.32 or newer) as compared to older versions (v0.4.31 or older).

Setting the following XLA flag fixes the issue for me in newer versions of JAX (i.e., speed is as before):

import os
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'

Unfortunately, a 5x slower runtime would probably kill any relevant use case of my simulator. As such, I have two questions:

  • Will the old CPU runtime continue to be maintained via the XLA flag?
  • Do you have any obvious candidates for operations that could cause this behavior? Any ideas on where I should start looking?

Related to #26021 and #25808

Thanks a lot!

To reproduce

As my codebase is fairly large, I cannot easily provide a self-contained example without relying on my toolbox. To reproduce:

pip install jaxley==0.1.2

and

import os
# XLA_FLAGS must be set before JAX initializes its backend,
# so set it before any jax import.
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'

from jax import config
config.update("jax_platform_name", "cpu")
config.update("jax_enable_x64", True)

import time
from jax import jit
import jaxley as jx
from jaxley.channels import HH

comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1] + [b // 2 for b in range(0, 2**8 - 2)])
cell.insert(HH())
cell.branch(0).comp(0).record()
cell.make_trainable("radius")
params = cell.get_parameters()

@jit
def simulate(params):
    return jx.integrate(cell, params=params, t_max=1_000.0)

start_time = time.time()
simulate(params).block_until_ready()
print("Compile time: ", time.time() - start_time)

start_time = time.time()
simulate(params).block_until_ready()
print("Run time: ", time.time() - start_time)

On my MacBook Pro (and JAX v0.5.0, jaxlib v0.5.0, Python 3.12), I get:

Compile time:  7.672
Run time:      2.067

Removing os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false', I get:

Compile time:  13.620
Run time:      11.293

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.12.8 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 10:37:40) [Clang 14.0.6 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='Michaels-MBP', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:11:08 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T8122', machine='arm64')

michaeldeistler avatar Jan 28 '25 10:01 michaeldeistler

EDIT: When using float32 (i.e., with x64 disabled), the gap between the old and new versions of JAX is even more prominent (almost 10x then...):

config.update("jax_enable_x64", False)

michaeldeistler avatar Jan 28 '25 10:01 michaeldeistler

Ping @ezhulenev and @penpornk for CPU thunks runtime performance

As noted, it's hard to say too much without a more minimal reproducer, but to answer the specific questions:

Will the old CPU runtime continue to be maintained via the XLA flag?

I think there continue to be enough performance regressions that we would like to keep this working for the time being. I think the right approach is to suggest the use of that flag for now with the assumption that the reported performance regressions will be reasonably addressed before it is removed. @ezhulenev and @penpornk might be able to comment more about timelines on the XLA side!

Do you have any obvious candidates for operations that could cause this behavior? Any ideas on where I should start looking?

The most common culprits in my experience are loops (scan or while_loop), although sometimes these can be implicit. If you have explicit scans in your library, you can experiment with trading compile time for runtime performance using the unroll parameter (increase it for longer compile times, but typically better runtime performance).
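To make the suggestion concrete, here is a minimal, hypothetical sketch (not from this thread; the toy update is invented for illustration) of passing unroll to jax.lax.scan:

```python
import jax
import jax.numpy as jnp

def step(carry, _):
    # Toy per-iteration update standing in for a real simulator step.
    new = carry * 0.999 + jnp.sin(carry)
    return new, new

@jax.jit
def run(x0):
    # unroll=8 asks XLA to inline 8 loop bodies per scan iteration:
    # compile time goes up, but CPU runtime often improves.
    final, trace = jax.lax.scan(step, x0, xs=None, length=40_000, unroll=8)
    return final, trace

final, trace = run(jnp.array(1.0))
```

Larger unroll values amortize per-iteration loop overhead over more work, which is exactly the trade-off described above.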

dfm avatar Jan 28 '25 17:01 dfm

We plan to fix the performance regression in the next couple of weeks (@cota), and only after that will we start removing the old runtime.

ezhulenev avatar Jan 28 '25 17:01 ezhulenev

That's great to know, thanks a lot for the quick response!

My simulator integrates an ODE and indeed uses a scan (across 40k timepoints), so this might well be the culprit. Unrolling is not an option for me because compile time becomes excessive.
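For context, a hypothetical, self-contained sketch of this kind of scan-driven ODE loop (the forward-Euler dynamics are invented for illustration and are not Jaxley's actual code):

```python
import jax
import jax.numpy as jnp

def euler_step(v, _):
    # Toy forward-Euler update; the real per-step dynamics are much larger.
    dt = 0.025
    return v + dt * (-0.1 * v + 1.0), v

@jax.jit
def integrate(v0):
    # 40k strictly sequential steps: with many tiny kernels per step,
    # per-iteration dispatch overhead dominates on CPU.
    _, vs = jax.lax.scan(euler_step, v0, xs=None, length=40_000)
    return vs

vs = integrate(jnp.array(-70.0))
```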

For now, I will simply use the old CPU runtime via the XLA flag. Thank you for the suggestions!

michaeldeistler avatar Jan 28 '25 19:01 michaeldeistler

@michaeldeistler — There have been some significant improvements in the CPU runtime with small loops. I'd be interested to know if things are better with the latest JAX (v0.6.0) on your benchmarks.

dfm avatar Apr 17 '25 14:04 dfm

Cool!

On JAX==0.6.0, I now have:

Compile time:  3.9083
Run time:      2.384

In other words: compile time is almost 2x faster than ever before! Run time is still a bit slower (~10-20%) than old JAX versions (<= v0.4.31), but much improved compared to, e.g., v0.4.35.

It would of course be fantastic to overcome the remaining 10-20% slowdown in runtime, but---for my applications---this slowdown is no longer a blocker. I will remove the version pin on my toolbox :)

Thanks for the heads-up, this is great! Michael

michaeldeistler avatar Apr 17 '25 14:04 michaeldeistler

Can you run it with XLA_FLAGS=--xla_dump_to=/path/to/dump/hlos? We'll take a look at why there is still a remaining ~10% regression (+ @WillFroom)

ezhulenev avatar Apr 17 '25 17:04 ezhulenev

Here you go!

hlo_v_0_4_31.zip hlo_v_0_6_0.zip

Thank you so much! Michael

michaeldeistler avatar Apr 17 '25 18:04 michaeldeistler

Thanks for the repro. I collected performance profiles, and the problem is that we call libm's exp for f64 data on a hot path. We'll be fixing this soon.

And the 10% regression you see is from the overheads of concurrent execution in the small loop. I'll look into that today; it might be easy to fix.
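As a hedged illustration (a rough micro-benchmark sketch of my own, not the profile collected here), the f64 exp cost can be made visible from the JAX side by timing the same jitted exp on float32 vs float64 inputs:

```python
import time

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

x64 = jnp.linspace(0.0, 1.0, 1_000_000, dtype=jnp.float64)
x32 = x64.astype(jnp.float32)

@jax.jit
def f(x):
    return jnp.exp(x)

for x in (x32, x64):
    f(x).block_until_ready()  # compile and warm up once per dtype
    t0 = time.perf_counter()
    for _ in range(10):
        f(x).block_until_ready()
    print(x.dtype, (time.perf_counter() - t0) / 10)
```

On affected versions one would expect a disproportionately large f64 timing; absolute numbers depend on the machine.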

ezhulenev avatar Apr 23 '25 21:04 ezhulenev

I have a "fix": it solves the 10% regression here, but adds a 10% regression to many important internal benchmarks :) I'll need more time to find a good cost model for the runtime heuristics.

ezhulenev avatar Apr 23 '25 23:04 ezhulenev

Any update on this issue?

vboulanger avatar Jun 06 '25 14:06 vboulanger

Not really. I could not find a good heuristic that makes things uniformly better. One option is to add a flag to XLA that gives the user some control over the scheduling strategy. Would this work for you?

ezhulenev avatar Jun 06 '25 16:06 ezhulenev

Thanks for your answer. Not ideal: the current version is still a regression for my use case relative to v0.4.31. My use case unfortunately combines several of the affected pieces: large scans with thousands of loop steps, the need for float64 precision, everything jitted, and many of these running in parallel on all CPU threads.

vboulanger avatar Jun 09 '25 15:06 vboulanger

@seantalts should have a fix for f64 precision soon

ezhulenev avatar Jun 09 '25 18:06 ezhulenev

Yep, working on it. Current benchmarks show pretty dramatic improvements!

name                             old time/op          new time/op          delta
BM_ExpF64/128/process_time       2.77µs ± 2%          1.03µs ± 5%  -62.91%  (p=0.000 n=40+39)
BM_ExpF64/256/process_time       4.79µs ± 2%          1.32µs ± 6%  -72.45%  (p=0.000 n=40+40)
BM_ExpF64/512/process_time       8.85µs ± 2%          1.89µs ± 5%  -78.60%  (p=0.000 n=40+40)
BM_ExpF64/1024/process_time      16.9µs ± 2%           3.1µs ± 7%  -81.88%  (p=0.000 n=40+40)
BM_ExpF64/4096/process_time      65.8µs ± 3%          10.1µs ± 8%  -84.63%  (p=0.000 n=39+40)
BM_ExpF64_Aot/128/process_time   2.87µs ± 2%          2.05µs ± 5%  -28.57%  (p=0.000 n=40+39)
BM_ExpF64_Aot/256/process_time   4.99µs ± 2%          3.34µs ± 5%  -33.04%  (p=0.000 n=40+40)
BM_ExpF64_Aot/512/process_time   9.24µs ± 2%          5.97µs ± 6%  -35.39%  (p=0.000 n=40+40)
BM_ExpF64_Aot/1024/process_time  17.7µs ± 2%          11.1µs ± 5%  -37.54%  (p=0.000 n=40+40)
BM_ExpF64_Aot/4096/process_time  69.0µs ± 2%          42.5µs ± 7%  -38.37%  (p=0.000 n=40+40)

seantalts avatar Jun 09 '25 20:06 seantalts

The changes have landed in XLA (final PR). Would it be easy for you to re-run your benchmark/test with this version and see how it compares now?

seantalts avatar Jun 18 '25 16:06 seantalts

And the easiest way to test that is probably to try with a jax/jaxlib nightly in a day or so.

hawkinsp avatar Jun 18 '25 21:06 hawkinsp

Hi everyone,

I just tried with the newest JAX nightly (JAX v0.6.3.20250618, jaxlib v0.6.3.20250619) and I get:

Compile time:  3.5977
Run time:      1.5199

This is a ~40% runtime speedup compared to v0.6.0, and a runtime speedup of ~25% compared to v0.4.31 and older. Thank you so much for your efforts, this is really amazing! Feel free to close this issue.

For future reference, here is an overview of the compile time and runtime of all versions that I checked:

# JAX v0.4.31
Compile time:  7.672
Run time:      2.067

# JAX v0.5.0
Compile time:  13.620
Run time:      11.293

# JAX v0.6.0
Compile time:  3.9083
Run time:      2.384

# JAX v0.6.3.20250618, jaxlib v0.6.3.20250619 (nightly)
Compile time:  3.5977
Run time:      1.5199

Cheers Michael

michaeldeistler avatar Jun 20 '25 07:06 michaeldeistler

I confirm that v0.6.3 is a solid improvement, well done! Can you please tell us what was changed? Is it the scan? The float64? Does it also fix this issue? Thanks

vboulanger avatar Jun 20 '25 13:06 vboulanger

Thank you both for testing!

Can you please tell us what was changed? Is it the scan? the float64?

The main change I'm aware of was to make the exponential function on float64 much faster. @ezhulenev would know if there were also other changes that would have helped here.

Is it also fixing https://github.com/jax-ml/jax/issues/26021?

I don't see any exp in the HLO for this issue, so probably not.

seantalts avatar Jun 23 '25 21:06 seantalts

Can we close this bug? Is everything filed here addressed?

hawkinsp avatar Jun 24 '25 13:06 hawkinsp

For my side it can be closed.

michaeldeistler avatar Jun 24 '25 13:06 michaeldeistler

Same here: for me it can be closed. Thanks

vboulanger avatar Jun 24 '25 13:06 vboulanger

Unfortunately, I noticed that computing the gradient of my model still suffers a large performance regression compared to v0.4.31 on CPU. I tested multiple versions (0.4.35, 0.6.0, 0.7.1), and all of them are slower than v0.4.31. However, the most recent JAX version, 0.7.1, is particularly slow at computing the gradient: about 8x slower than v0.4.31. Below are the HLOs:

v0_4_31.zip v0_7_1.zip

Any ideas? Thank you very much for your help! Michael
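For anyone reproducing this, here is a hypothetical, self-contained sketch of the gradient timing (the simulate stand-in and its "g" parameter are invented so the snippet runs without jaxley; substitute the jaxley model from the repro above for the real measurement):

```python
import time

import jax
import jax.numpy as jnp

# Stand-in for the jaxley simulation: a small scan-based ODE so the
# sketch is runnable on its own. "g" is an invented trainable parameter.
def simulate(params):
    def step(v, _):
        return v + 0.025 * (-params["g"] * v + 1.0), v
    _, vs = jax.lax.scan(step, jnp.zeros(()), xs=None, length=10_000)
    return vs

@jax.jit
def loss(params):
    # Any scalar reduction of the trace works for benchmarking purposes.
    return jnp.sum(simulate(params) ** 2)

grad_fn = jax.jit(jax.grad(loss))
params = {"g": 0.1}

jax.block_until_ready(grad_fn(params))  # compile once
t0 = time.perf_counter()
jax.block_until_ready(grad_fn(params))
print("grad time:", time.perf_counter() - t0)
```

Reverse-mode through a long scan stores and replays the per-step residuals, so gradient timings can regress independently of the forward pass.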

michaeldeistler avatar Aug 27 '25 13:08 michaeldeistler

@WillFroom @ezhulenev PTAL at @michaeldeistler's repro?

hawkinsp avatar Aug 27 '25 14:08 hawkinsp

Hi @michaeldeistler, the particular slowdown in 0.7.1 looks like it has the same root cause as #31284, for which I have an in-flight fix here: https://github.com/openxla/xla/pull/30679. I will look into the remaining slowdown vs v0.4.31.

WillFroom avatar Aug 27 '25 15:08 WillFroom

Perfect, thanks! Let me know if HLOs of intermediate JAX versions (e.g., 0.6.0, 0.4.35,...) would be useful.

michaeldeistler avatar Aug 27 '25 15:08 michaeldeistler