JAX unable to match NVIDIA driver version
Description
Similar to this issue and this discussion, I'm on an HPC cluster where the NVIDIA GPUs' driver supports CUDA 12.4. When I install JAX with pip install jax[cuda12] or with pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, I get a warning suggesting that JAX is not matching the 12.4 driver version, and is instead trying to use CUDA 12.5:
2024-07-19 09:58:22.022601: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA
version is 12.4 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the
ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update
your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
I understand that this isn't blocking, but I don't relish the idea of randomly slowed compilation times. Is there a reason JAX is unable to match the NVIDIA driver?
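One possible workaround (a sketch, not something I've verified on this cluster): the ptxas that the warning complains about ships in the nvidia-cuda-nvcc-cu12 wheel that pip pulls in as a dependency of jax[cuda12], so pinning that wheel to the driver's CUDA version may silence the mismatch, assuming a 12.4.x build of the wheel is available:

```shell
# Pin the pip-provided CUDA compiler (which contains ptxas) to the
# driver's CUDA version. Assumption: a 12.4.x nvidia-cuda-nvcc-cu12
# wheel exists and is compatible with the installed jaxlib.
pip install "jax[cuda12]" "nvidia-cuda-nvcc-cu12==12.4.*"

# Confirm which version of the ptxas package is installed:
pip show nvidia-cuda-nvcc-cu12
```

Note that pip may warn about dependency conflicts if jaxlib's metadata requires a newer nvcc wheel; in that case the pin isn't viable and the forward-compatibility route below is the supported path.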
Note: I have been able to get JAX to work with my cluster's provided CUDA 12.1 driver by using the following pip arguments:
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_cudnn89]==0.4.14
jaxlib==0.4.14+cuda12.cudnn89
But I anticipate that sooner or later I will want to use a JAX version that is newer than what my cluster provides.
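For reference, the pinned install above can be expressed as a single command (versions are the ones from this report and will need bumping over time):

```shell
# Install a JAX/jaxlib pair pinned to a CUDA 12 + cuDNN 8.9 build that
# is known to work with this cluster's older driver. Versions are from
# this report, not a general recommendation.
pip install "jax[cuda12_cudnn89]==0.4.14" "jaxlib==0.4.14+cuda12.cudnn89" \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```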
Thanks!
System info (python version, jaxlib version, accelerator, etc.)
jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0
python: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='compute-g-17-155.o2.rc.hms.harvard.edu', release='3.10.0-1160.118.1.el7.x86_64', version='#1 SMP Wed Apr 24 16:01:50 UTC 2024', machine='x86_64')
$ nvidia-smi
Fri Jul 19 10:41:24 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Quadro RTX 8000                On  |   00000000:AF:00.0 Off |                    0 |
| N/A   27C    P0             23W / 250W  |    166MiB /  46080MiB  |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    263594      C   python                                        164MiB |
+-----------------------------------------------------------------------------------------+
Having the same issue here. Any ideas?
It is my understanding that the CUDA forward compatibility packages exist for exactly this reason: installing-compat-packages
So, even though I haven't tried it, installing the necessary compat package and then adding it to LD_LIBRARY_PATH should do it.
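Concretely, the forward-compatibility route might look like the following. This is an untested sketch: the package name, paths, and package manager are the standard NVIDIA ones for a CUDA 12.5 compat package on an RHEL-family node, and installing it system-wide requires root, which HPC users typically don't have (in that case you'd extract the package into your home directory and point LD_LIBRARY_PATH there instead):

```shell
# Untested sketch: install NVIDIA's forward-compatibility package so a
# CUDA 12.5 userspace (ptxas 12.5.82) can run on the 12.4 driver.
# Package name and install path are assumptions; see NVIDIA's
# "installing the compat packages" docs for your distro.
sudo yum install -y cuda-compat-12-5

# Put the compat libraries ahead of the driver's libcuda at runtime:
export LD_LIBRARY_PATH=/usr/local/cuda-12.5/compat:$LD_LIBRARY_PATH

# Re-run JAX and check whether the nvptx_compiler warning is gone:
python -c "import jax; print(jax.devices())"
```

The key point is that LD_LIBRARY_PATH matters when JAX runs, not when pip installs it, so the export needs to be in place (e.g. in your job script) for every session that uses the GPU.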