
ptxas version mismatch

Open misterguick opened this issue 1 year ago • 0 comments

Hi all,

I work on a cluster with CUDA 12.4, and I'm trying to use Brax through the TorchRL wrapper. Among other issues, I keep running into a strange compatibility warning. Here is the setup:

```python
import brax.envs

base_env = brax.envs.get_environment("halfcheetah")
```

which gives

```
W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
```
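To make the condition in the warning concrete, here is a minimal Python sketch that pulls the two version numbers out of the warning text and compares them. It only restates the check as the warning describes it; the helper names (`parse_version`, `driver_older_than_ptxas`) are hypothetical, and this is not XLA's actual code. Comparing only major.minor is an assumption, made because the driver reports no patch level.

```python
import re

# The warning text from the issue, copied verbatim (log prefix dropped).
WARNING = (
    "The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas "
    "CUDA version (12.5.40). Because the driver is older than the ptxas version, "
    "XLA is disabling parallel compilation, which may slow down compilation."
)


def parse_version(text: str) -> tuple[int, ...]:
    """Turn a dotted version string such as '12.5.40' into (12, 5, 40)."""
    return tuple(int(part) for part in text.split("."))


def driver_older_than_ptxas(warning: str) -> bool:
    """Extract both versions from the warning and compare major.minor.

    Hypothetical helper: it restates the condition the warning describes,
    it is not XLA's implementation.
    """
    driver = re.search(r"driver's CUDA version is ([\d.]+) ", warning)
    ptxas = re.search(r"ptxas CUDA version \(([\d.]+)\)", warning)
    if driver is None or ptxas is None:
        raise ValueError("warning text does not contain both versions")
    # Assumption: compare major.minor only, since the driver has no patch level.
    return parse_version(driver.group(1))[:2] < parse_version(ptxas.group(1))[:2]


print(driver_older_than_ptxas(WARNING))  # True, because (12, 4) < (12, 5)
```

With a ptxas from the 12.4 toolkit the same comparison would be `(12, 4) < (12, 4)`, which is false, so no warning would be expected.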

This is not a blocking issue, but it is still very annoying.

Running `nvidia-smi` gives:

```
NVIDIA-SMI 550.78 Driver Version: 550.78 CUDA Version: 12.4
```

Checking the system `ptxas` version gives:

```
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:14:54_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
```

There seems to be a mismatch between the warning, which reports ptxas 12.5.40, and this output, which reports ptxas 12.4.131.
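The mismatch can be checked the same way. The sketch below extracts the full version from the `ptxas` banner quoted in the issue and compares it with the 12.5.40 reported in the warning. `banner_version` is a hypothetical helper; if the two strings differ, the `ptxas` that XLA invokes is presumably not the one that was queried here.

```python
import re

# Output of the system ptxas, as quoted in the issue.
BANNER = """ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:14:54_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0"""

# ptxas version reported in the XLA warning.
WARNING_PTXAS = "12.5.40"


def banner_version(banner: str) -> str:
    """Return the full version from the 'Cuda compilation tools' line."""
    match = re.search(r"release [\d.]+, V([\d.]+)", banner)
    if match is None:
        raise ValueError("no 'Cuda compilation tools' line found")
    return match.group(1)


system_ptxas = banner_version(BANNER)
print(system_ptxas)  # 12.4.131
# Differing versions suggest two different ptxas binaries are involved.
print(system_ptxas == WARNING_PTXAS)  # False
```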

Here are the versions of the related packages:

```
# Name              Version                Build     Channel
brax                0.10.4                 pypi_0    pypi
jax                 0.4.28                 pypi_0    pypi
jax-cuda12-pjrt     0.4.28                 pypi_0    pypi
jax-cuda12-plugin   0.4.28                 pypi_0    pypi
jaxlib              0.4.28+cuda12.cudnn89  pypi_0    pypi
jaxopt              0.8.3                  pypi_0    pypi
jaxtyping           0.2.29                 pypi_0    pypi
```
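One way to look for a second `ptxas` is to list the NVIDIA wheels that pip installed into the same environment. This is an assumption-driven diagnostic, not a confirmed cause: it assumes the newer `ptxas` may come from a pip-installed CUDA wheel such as `nvidia-cuda-nvcc-cu12`, which a JAX CUDA 12 pip install can pull in. The helper names `installed_pairs` and `nvidia_wheels` are hypothetical; `importlib.metadata.distributions()` is standard library.

```python
from importlib import metadata


def installed_pairs() -> list[tuple[str, str]]:
    """(name, version) for every distribution visible to this interpreter."""
    pairs = []
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if name:  # skip distributions with broken metadata
            pairs.append((name, dist.version))
    return pairs


def nvidia_wheels(pairs) -> dict[str, str]:
    """Keep only packages whose normalized name starts with 'nvidia-'."""
    found = {}
    for name, version in pairs:
        if name.lower().replace("_", "-").startswith("nvidia-"):
            found[name] = version
    return found


if __name__ == "__main__":
    # Run this with the same interpreter/environment that imports brax and jax.
    for name, version in sorted(nvidia_wheels(installed_pairs()).items()):
        print(f"{name}=={version}")
```

If an `nvidia-cuda-nvcc-cu12` entry newer than 12.4 shows up, that would be consistent with the warning; if nothing is listed, the newer `ptxas` likely comes from somewhere else.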

I tried reinstalling everything (JAX, Brax, CUDA) with both pip and conda, but the warning never went away.

Thank you very much in advance!

misterguick, May 30 '24 22:05