brax
ptxas version mismatch
Hi all,
I work on a cluster with CUDA 12.4. I'm trying to use brax through the TorchRL wrapper. Among other issues, I keep running into a strange compatibility warning. Here is the setup:
import brax.envs
base_env = brax.envs.get_environment("halfcheetah")
which gives
W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
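The warning describes a simple version comparison: if the driver's CUDA version is older than the CUDA version of the ptxas that XLA uses, parallel compilation is turned off. A minimal sketch of that check, using the two versions from the warning, is below. It is an illustration only, not XLA's actual code, and comparing just the major.minor parts is an assumption about how the check works.

```python
# Illustrative sketch only: mimics the comparison described in the XLA warning.
# It is NOT XLA's implementation.


def parse_version(text: str) -> tuple[int, ...]:
    """Turn a dotted version string like '12.5.40' into (12, 5, 40)."""
    return tuple(int(part) for part in text.split("."))


def parallel_compilation_allowed(driver_cuda: str, ptxas_cuda: str) -> bool:
    """Return False when the driver's CUDA version is older than ptxas's.

    Assumption: only (major, minor) are compared. Comparing full tuples would
    treat '12.4' as older than '12.4.131', which is not a meaningful mismatch.
    """
    driver = parse_version(driver_cuda)[:2]
    ptxas = parse_version(ptxas_cuda)[:2]
    return driver >= ptxas


# Versions taken from the warning text.
print(parallel_compilation_allowed("12.4", "12.5.40"))   # False
print(parallel_compilation_allowed("12.4", "12.4.131"))  # True
```

With a 12.4 driver, only a ptxas from CUDA 12.5 or newer would trip this check, which is why the 12.5.40 in the warning is the interesting number.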
This is not a blocking issue, but it is still very annoying.
If I run nvidia-smi, I get
NVIDIA-SMI 550.78 Driver Version: 550.78 CUDA Version: 12.4
If I query ptxas for its version, I get
ptxas: NVIDIA (R) Ptx optimizing assembler Copyright (c) 2005-2024 NVIDIA Corporation Built on Thu_Mar_28_02:14:54_PDT_2024 Cuda compilation tools, release 12.4, V12.4.131 Build cuda_12.4.r12.4/compiler.34097967_0
There seems to be a mismatch between the ptxas version reported in the warning (12.5.40) and the one reported by this last step (12.4).
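To make the comparison concrete, here is a small sketch that pulls the CUDA versions out of the two outputs quoted in this post with regular expressions. The regex patterns are assumptions based only on these two sample lines, not on any documented output format.

```python
import re

# Output lines copied from the post (nvidia-smi header and ptxas version banner).
NVIDIA_SMI_OUTPUT = "NVIDIA-SMI 550.78 Driver Version: 550.78 CUDA Version: 12.4"
PTXAS_OUTPUT = (
    "ptxas: NVIDIA (R) Ptx optimizing assembler Copyright (c) 2005-2024 "
    "NVIDIA Corporation Built on Thu_Mar_28_02:14:54_PDT_2024 Cuda compilation "
    "tools, release 12.4, V12.4.131 Build cuda_12.4.r12.4/compiler.34097967_0"
)


def driver_cuda_version(nvidia_smi_output: str) -> str:
    """Extract the 'CUDA Version' field that nvidia-smi reports for the driver."""
    match = re.search(r"CUDA Version:\s*(\d+(?:\.\d+)*)", nvidia_smi_output)
    if match is None:
        raise ValueError("no 'CUDA Version' field found")
    return match.group(1)


def ptxas_cuda_version(ptxas_output: str) -> str:
    """Extract the full 'V<major>.<minor>.<patch>' version from the ptxas banner."""
    match = re.search(r"\bV(\d+(?:\.\d+)+)", ptxas_output)
    if match is None:
        raise ValueError("no 'V<version>' field found")
    return match.group(1)


driver = driver_cuda_version(NVIDIA_SMI_OUTPUT)
ptxas = ptxas_cuda_version(PTXAS_OUTPUT)
print(driver, ptxas)  # 12.4 12.4.131
# major.minor agree, so this particular ptxas should not trigger the warning.
print(driver.split(".")[:2] == ptxas.split(".")[:2])  # True
```

Since the ptxas on the system agrees with the driver, the 12.5.40 ptxas named in the warning is presumably a different binary from the one queried here.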
Here are the versions of the related packages:
brax                 0.10.4                  pypi_0  pypi
jax                  0.4.28                  pypi_0  pypi
jax-cuda12-pjrt      0.4.28                  pypi_0  pypi
jax-cuda12-plugin    0.4.28                  pypi_0  pypi
jaxlib               0.4.28+cuda12.cudnn89   pypi_0  pypi
jaxopt               0.8.3                   pypi_0  pypi
jaxtyping            0.2.29                  pypi_0  pypi
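One hypothesis worth checking (an assumption, not something established in this thread) is that the jax CUDA 12 wheels pulled in NVIDIA's pip-packaged CUDA components, whose distribution names start with `nvidia-`, and that one of them ships its own, newer `ptxas` separate from the system CUDA 12.4 install. The standard-library-only sketch below lists such packages and every `ptxas` it can find, either on `PATH` or among the files recorded by installed pip distributions. The function names are hypothetical helpers written for this example.

```python
import importlib.metadata
import shutil


def nvidia_pip_packages() -> dict[str, str]:
    """Map each installed distribution whose name starts with 'nvidia-' to its version.

    Assumption: NVIDIA's pip-packaged CUDA components use the 'nvidia-' prefix.
    """
    found: dict[str, str] = {}
    for dist in importlib.metadata.distributions():
        name = dist.metadata.get("Name") or ""
        if name.lower().startswith("nvidia-"):
            found[name] = dist.version
    return found


def ptxas_candidates() -> list[str]:
    """List the ptxas on PATH plus any file named 'ptxas' shipped by a pip package."""
    candidates: list[str] = []
    on_path = shutil.which("ptxas")
    if on_path is not None:
        candidates.append(on_path)
    for dist in importlib.metadata.distributions():
        # dist.files is None when the distribution has no RECORD file.
        for file in dist.files or []:
            if file.name == "ptxas":
                candidates.append(str(file.locate()))
    return candidates


for name, version in sorted(nvidia_pip_packages().items()):
    print(f"{name}=={version}")
for path in ptxas_candidates():
    print("ptxas found at:", path)
```

If a second `ptxas` turns up under site-packages and reports 12.5 when queried for its version, that would be consistent with the 12.5.40 in the warning.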
I tried reinstalling everything (jax, brax, CUDA) with both pip and conda, but the warning never went away.
Thank you very much in advance!