
Compilation takes way more time in new version

Open · stergiosba opened this issue 2 months ago · 3 comments

Description

Hey JAX team, thanks for the amazing work you are all doing on this project.

I have been working with JAX for quite some time now, and ever since the 0.4.27 update my code takes 5x to 10x longer to compile, even though the code itself has not changed.

It's quite a lot of code, and I'm not sure how you could reproduce what I am seeing, but here are some numbers from my machine, collected with JAX_LOG_COMPILES enabled:

JAX version 0.4.26:

Finished jaxpr to MLIR module conversion jit(train) in 2.1959640979766846 sec
Finished XLA compilation of jit(train) in 21.94405436515808 sec

JAX versions 0.4.27 and 0.4.28:

Finished jaxpr to MLIR module conversion jit(train) in 2.4314699172973633 sec
Finished XLA compilation of jit(train) in 97.583487033844 sec
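The two log lines above cover two separate stages: lowering (jaxpr to MLIR module) and XLA compilation, and only the second one regressed. A minimal sketch for timing the two stages separately with JAX's ahead-of-time API (jax.jit(...).lower(...).compile()), so the same measurement can be repeated in a 0.4.26 environment and a 0.4.28 environment. The train function and the input shape are hypothetical placeholders, not the code from this report.

```python
import time

import jax
import jax.numpy as jnp

# Same effect as running with JAX_LOG_COMPILES=1 in the environment.
jax.config.update("jax_log_compiles", True)


def train(x):
    # Hypothetical stand-in for the real `train` function.
    for _ in range(10):
        x = jnp.tanh(x @ x.T + 1.0)
    return x.sum()


def time_compile(fn, *args):
    """Return (lowering_seconds, compile_seconds, compiled) for jit(fn)."""
    jitted = jax.jit(fn)
    t0 = time.perf_counter()
    lowered = jitted.lower(*args)   # trace + jaxpr -> MLIR module
    t1 = time.perf_counter()
    compiled = lowered.compile()    # XLA compilation for the default backend
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, compiled


x = jnp.ones((128, 128))
lower_s, compile_s, compiled = time_compile(train, x)
print(f"jax {jax.__version__}: lowering {lower_s:.3f}s, XLA compile {compile_s:.3f}s")
```

Running this unchanged in both environments isolates the compiler stage, which is where the reported jump from about 22 s to about 98 s occurs.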

Again, this is the same code with exactly the same dependencies; the only difference is the JAX version. I also noticed that in version 0.4.28 recompilation is triggered for parts of my code.
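To check which parts of the code are being recompiled, one option is to put a Python side effect inside the jitted function. The side effect only runs when JAX traces the function, which happens on a cache miss, so a counter shows how often a new trace and compilation was needed. This is a minimal sketch; step and its arguments are hypothetical. Newer JAX releases also have a jax_explain_cache_misses option that logs the reason for each miss, but check that your installed version supports it.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def step(params, batch):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return params - 0.1 * jnp.mean(batch) * params


params = jnp.ones(4)
step(params, jnp.ones((8, 4)))                     # first call: trace + compile
step(params, jnp.ones((8, 4)))                     # same shapes/dtypes: cache hit
step(params, jnp.ones((16, 4)))                    # new batch shape: retrace
step(params, jnp.ones((8, 4), dtype=jnp.float16))  # new dtype: retrace
print(trace_count)  # 3
```

If the same inputs give a higher count under 0.4.28 than under 0.4.26, that points to an argument whose abstract signature (shape, dtype, weak type, or a static argument) differs between calls.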

What do you guys think?

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='apollo', release='6.5.0-26-generic', version='#26~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Mar 12 10:22:43 UTC 2', machine='x86_64')


$ nvidia-smi
Wed May 15 15:53:39 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3070 Ti     Off |   00000000:01:00.0 Off |                  N/A |
| 33%   42C    P2             71W /  290W |     904MiB /   8192MiB |     19%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1850      G   /usr/lib/xorg/Xorg                            365MiB |
|    0   N/A  N/A      2091      G   /usr/bin/gnome-shell                           70MiB |
|    0   N/A  N/A    622325      G   ...irefox/4090/usr/lib/firefox/firefox          0MiB |
|    0   N/A  N/A    650168      G   ...erProcess --variations-seed-version         59MiB |
|    0   N/A  N/A    841800      C   python                                        160MiB |
+-----------------------------------------------------------------------------------------+
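The version, device and platform block above, together with the nvidia-smi table when an NVIDIA driver is present, is the report that jax.print_environment_info() produces. A minimal sketch for capturing it as a string, for example to diff the 0.4.26 and 0.4.28 environments:

```python
import jax

# return_string=True returns the report instead of printing it.
info = jax.print_environment_info(return_string=True)
print(info)
```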

stergiosba, May 15 '24 19:05