Compilation takes way more time in new version
Description
Hey JAX team, thanks for the amazing work you are all doing on this project.
I have been working with JAX for quite some time now, and ever since the 0.4.27 update my code takes 5x to 10x longer to compile, with no changes on my side.
It's quite a bit of code, so I don't know how you could reproduce what I am seeing, but here are some numbers from my machine, collected with JAX_LOG_COMPILES:
Jax Version 0.4.26:
Finished jaxpr to MLIR module conversion jit(train) in 2.1959640979766846 sec
Finished XLA compilation of jit(train) in 21.94405436515808 sec
Jax Version 0.4.27 and 0.4.28:
Finished jaxpr to MLIR module conversion jit(train) in 2.4314699172973633 sec
Finished XLA compilation of jit(train) in 97.583487033844 sec
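To make the comparison explicit, here is a small sketch that parses the log lines quoted above and computes the per-stage slowdown. It uses only the Python standard library; the regular expression and the helper name `parse_stage_seconds` are assumptions made for this illustration and are not part of JAX.

```python
import re

# Log lines copied verbatim from the report above.
LOG_0_4_26 = """\
Finished jaxpr to MLIR module conversion jit(train) in 2.1959640979766846 sec
Finished XLA compilation of jit(train) in 21.94405436515808 sec
"""
LOG_0_4_27 = """\
Finished jaxpr to MLIR module conversion jit(train) in 2.4314699172973633 sec
Finished XLA compilation of jit(train) in 97.583487033844 sec
"""

# Assumed pattern for the two message shapes seen in this report.
PATTERN = re.compile(
    r"Finished (?P<stage>.+?) (?:of )?jit\((?P<name>\w+)\) in (?P<sec>[\d.]+) sec"
)

def parse_stage_seconds(log: str) -> dict[str, float]:
    """Map each stage name found in the log to its duration in seconds."""
    return {m["stage"]: float(m["sec"]) for m in PATTERN.finditer(log)}

old = parse_stage_seconds(LOG_0_4_26)
new = parse_stage_seconds(LOG_0_4_27)

# Ratio of new to old duration for every stage present in both logs.
slowdown = {stage: new[stage] / old[stage] for stage in old if stage in new}
for stage, ratio in slowdown.items():
    print(f"{stage}: {old[stage]:.2f} sec -> {new[stage]:.2f} sec ({ratio:.2f}x)")
```

For these particular numbers the jaxpr to MLIR conversion is nearly unchanged (about 1.1x), while XLA compilation is roughly 4.4x slower, so the regression appears to sit in the XLA compilation stage.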
Again, this is the same code with exactly the same dependencies; the only difference is the JAX version. I also noticed that recompilation is triggered for parts of my code in version 0.4.28.
What do you guys think?
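Since the full code is hard to share, a skeleton like the following may help anyone trying to narrow this down. It is only a sketch under stated assumptions: `train` and `step` are hypothetical stand-ins, not the actual code from this report. It enables compile logging from Python, times lowering and XLA compilation separately through the ahead-of-time `lower` / `compile` API, and counts how often a jitted function is retraced.

```python
import time

import jax
import jax.numpy as jnp

# Same effect as running with the JAX_LOG_COMPILES=1 environment variable.
jax.config.update("jax_log_compiles", True)


def train(x):
    # Hypothetical stand-in for the real training step.
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((128, 128))

# Time the two stages separately with the ahead-of-time API.
start = time.perf_counter()
lowered = jax.jit(train).lower(x)  # tracing + conversion to an MLIR module
lower_sec = time.perf_counter() - start

start = time.perf_counter()
compiled = lowered.compile()  # XLA compilation
compile_sec = time.perf_counter() - start

result = compiled(x)
print(f"lowering: {lower_sec:.3f} sec, XLA compilation: {compile_sec:.3f} sec")

# Count retraces: a Python side effect inside a jitted function only runs
# while JAX traces it, not when a cached executable is reused.
trace_count = 0


@jax.jit
def step(v):
    # Hypothetical second function, used only to demonstrate retrace counting.
    global trace_count
    trace_count += 1
    return v * 2.0


step(jnp.ones(3))  # first call: traced and compiled
step(jnp.ones(3))  # same shape and dtype: served from the cache
step(jnp.ones(4))  # new shape: traced and compiled again
print(f"traces of step: {trace_count}")
```

Running the same script under 0.4.26 and 0.4.28 and comparing the printed timings and trace counts would show whether a smaller function reproduces the slowdown or the extra recompilations. Note that the counter measures tracing, which normally precedes a new compilation but is not identical to it.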
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='apollo', release='6.5.0-26-generic', version='#26~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Mar 12 10:22:43 UTC 2', machine='x86_64')
$ nvidia-smi
Wed May 15 15:53:39 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3070 Ti     Off |   00000000:01:00.0 Off |                  N/A |
| 33%   42C    P2             71W /  290W |     904MiB /   8192MiB |     19%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1850      G   /usr/lib/xorg/Xorg                            365MiB |
|    0   N/A  N/A      2091      G   /usr/bin/gnome-shell                           70MiB |
|    0   N/A  N/A    622325      G   ...irefox/4090/usr/lib/firefox/firefox          0MiB |
|    0   N/A  N/A    650168      G   ...erProcess --variations-seed-version         59MiB |
|    0   N/A  N/A    841800      C   python                                        160MiB |
+-----------------------------------------------------------------------------------------+