
JAX unable to match NVIDIA driver version

Open jonahpearl opened this issue 1 year ago • 1 comments

Description

Similar to this issue and this discussion, I'm on an HPC cluster where the NVIDIA driver reports CUDA 12.4. When I try to install jax with pip install jax[cuda12] or with pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, I get a warning suggesting that JAX is not heeding the 12.4 driver version and is instead trying to use CUDA 12.5:

2024-07-19 09:58:22.022601: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

I understand that this isn't blocking, but I don't relish the idea of randomly slowed compilation times. Is there a reason JAX is unable to match the NVIDIA driver?
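To make the mismatch in the warning concrete, here is a minimal Python sketch that pulls the two version numbers out of the message and compares them. The regular expression is written against the exact wording quoted above, which is an assumption: other XLA versions may phrase the warning differently. The helper names parse_versions and driver_is_older are hypothetical.

```python
import re

# Pattern written against the warning text quoted in this issue (assumption:
# other XLA releases may word the message differently).
_WARNING_RE = re.compile(
    r"driver's CUDA\s+version is (\d+(?:\.\d+)*) which is older than "
    r"the ptxas CUDA version \((\d+(?:\.\d+)*)\)"
)


def _to_tuple(version):
    """Turn '12.5.82' into (12, 5, 82) so versions compare numerically."""
    return tuple(int(part) for part in version.split("."))


def parse_versions(message):
    """Return (driver_version, ptxas_version) as int tuples, or None if no match."""
    # Collapse whitespace so a message wrapped across terminal lines still matches.
    flat = " ".join(message.split())
    match = _WARNING_RE.search(flat)
    if match is None:
        return None
    return _to_tuple(match.group(1)), _to_tuple(match.group(2))


def driver_is_older(driver, ptxas):
    """Compare major.minor only, since the driver version carries no patch level."""
    return driver[:2] < ptxas[:2]


warning = (
    "The NVIDIA driver's CUDA version is 12.4 which is older than "
    "the ptxas CUDA version (12.5.82)."
)
driver, ptxas = parse_versions(warning)
print(driver, ptxas, driver_is_older(driver, ptxas))
```

For the message in this issue, the driver side parses as 12.4 and the ptxas side as 12.5.82, so the comparison reports the driver as older, which matches what the warning says.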

Note: I have been able to get JAX to work with my cluster's provided CUDA 12.1 drivers by using the following pip args:

-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_cudnn89]==0.4.14
jaxlib==0.4.14+cuda12.cudnn89

But I anticipate that sooner or later I will want to use a JAX version that is newer than what my cluster provides.

Thanks!

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.30
jaxlib: 0.4.30
numpy: 2.0.0
python: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='compute-g-17-155.o2.rc.hms.harvard.edu', release='3.10.0-1160.118.1.el7.x86_64', version='#1 SMP Wed Apr 24 16:01:50 UTC 2024', machine='x86_64')

$ nvidia-smi
Fri Jul 19 10:41:24 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Quadro RTX 8000                On  |   00000000:AF:00.0 Off |                    0 |
| N/A   27C    P0             23W /  250W |     166MiB /  46080MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    263594      C   python                                        164MiB |
+-----------------------------------------------------------------------------------------+

jonahpearl, Jul 19 '24 14:07

I'm having the same issue here. Any ideas?

epignatelli, Aug 20 '24 14:08

My understanding is that the CUDA forward compatibility packages exist for exactly this situation (see installing-compat-packages).

So even though I haven't tried it, installing the necessary compat package and then changing LD_LIBRARY_PATH before pip installing should do it.

alonfnt, Sep 23 '24 05:09
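The LD_LIBRARY_PATH suggestion above could be sketched roughly as follows in Python. This is untested against a real cluster and rests on assumptions: the compat directory path is hypothetical (the actual location depends on how the compat package was installed), and the helper name env_with_compat is made up for illustration. The sketch only builds an environment mapping with the compat directory placed first; it does not install anything.

```python
import os


def env_with_compat(compat_dir, base_env=None):
    """Return a copy of the environment with compat_dir first on LD_LIBRARY_PATH."""
    env = dict(os.environ if base_env is None else base_env)
    existing = env.get("LD_LIBRARY_PATH", "")
    # Drop empty entries and any earlier copy of compat_dir to avoid duplicates.
    rest = [p for p in existing.split(os.pathsep) if p and p != compat_dir]
    env["LD_LIBRARY_PATH"] = os.pathsep.join([compat_dir] + rest)
    return env


# Hypothetical location: replace with wherever the compat libraries actually live.
COMPAT_DIR = "/usr/local/cuda-12.5/compat"

env = env_with_compat(COMPAT_DIR, {"LD_LIBRARY_PATH": "/opt/lib"})
print(env["LD_LIBRARY_PATH"])
```

The resulting mapping can be passed as the env argument of subprocess.run when launching the script that imports jax. The dynamic loader typically reads LD_LIBRARY_PATH when a process starts, so setting it inside an already running Python session is unlikely to have the same effect.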