
JAX 0.4.14 jaxlib cudnn error - non-docker installation

Open AssmannG opened this issue 1 year ago • 2 comments

Hi,

I am currently installing the latest AlphaFold (from the main branch, so 2.3.2 plus new features: openmm 7.7.0 with no openmm patch, Python 3.10, etc.) on an HPC cluster without Docker. I am confused about the jax/jaxlib versions: requirements.txt pins jax==0.4.14, but the Dockerfile still uses 0.3.25.

CUDA 11.2.2 (I also tried 11.8.0 and 11.1.0). I generate a conda env with the following file (taken from the Dockerfile and requirements.txt):

channels:
  - pytorch
  - conda-forge
  - defaults
  - anaconda
  - bioconda
dependencies:
  - python==3.10
  - pip 
  - openmm==7.7.0
  - cudnn  # Change version if not compatible with current system
  - cudatoolkit
  - pdbfixer
  - hmmer==3.4
  - hhsuite==3.3.0
  - kalign2==2.04
  - pip:
    - absl-py==1.0.0
    - biopython==1.79 
    - chex==0.0.7 
    - dm-haiku==0.0.10 
    - immutabledict==2.0.0 
    - ml-collections==0.1.0  
    - numpy==1.24.3 
    - scipy==1.11.1 
    - tensorflow-cpu==2.13.0 
    - jax==0.4.14 
    - pandas==2.0.3 
    - dm-tree==0.1.8 

and then run:

pip3 install --upgrade --no-cache-dir jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn805 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
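The environment file above pins jax==0.4.14 in its pip section, and the command just above then pins jax==0.3.25 together with jaxlib 0.3.25, so the later step replaces the earlier one. The sketch below (hypothetical helper names, plain string parsing, not a pip or conda API) flags such clashes before an environment is built. It assumes that jax and jaxlib at the same base release is the safe pairing; JAX mainly checks for a minimum jaxlib version at import time, so a flagged mismatch is a hint, not proof of breakage.

```python
"""Flag packages pinned to different versions in two install steps.

Hypothetical helpers: parse_pins, base_version, find_conflicts and
jax_jaxlib_mismatch are illustrative names, not part of AlphaFold, pip or conda.
"""


def parse_pins(specs):
    """Map package name -> version for 'name==version' requirement strings."""
    pins = {}
    for spec in specs:
        name, sep, version = spec.partition("==")
        if sep:  # skip unpinned entries such as 'cudnn' or 'pip'
            pins[name.strip().lower()] = version.strip()
    return pins


def base_version(version):
    """Drop a local tag: '0.3.25+cuda11.cudnn805' -> '0.3.25'."""
    return version.split("+", 1)[0]


def find_conflicts(first, second):
    """Return {name: (first_version, second_version)} where the pins disagree."""
    a, b = parse_pins(first), parse_pins(second)
    return {
        name: (a[name], b[name])
        for name in a.keys() & b.keys()
        if base_version(a[name]) != base_version(b[name])
    }


def jax_jaxlib_mismatch(specs):
    """True if jax and jaxlib are both pinned, but to different base releases."""
    pins = parse_pins(specs)
    if "jax" not in pins or "jaxlib" not in pins:
        return False
    return base_version(pins["jax"]) != base_version(pins["jaxlib"])


# Pins copied from the environment file and the pip command in this thread.
env_pip_section = ["jax==0.4.14", "numpy==1.24.3", "tensorflow-cpu==2.13.0"]
manual_install = ["jax==0.3.25", "jaxlib==0.3.25+cuda11.cudnn805"]

print(find_conflicts(env_pip_section, manual_install))  # {'jax': ('0.4.14', '0.3.25')}
print(jax_jaxlib_mismatch(manual_install))  # False
```

Comparing base versions (ignoring the +cuda/cudnn tag) keeps the check focused on the release numbers, which is where the requirements.txt vs Dockerfile confusion lies.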

If I run

python -c "import jax; print(f'Jax backend: {jax.default_backend()}')"

I get Jax backend: gpu, and jax.devices() returns [gpu(id=0)].
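The one-liner can be extended into a slightly fuller check that also prints the installed jax and jaxlib versions, which is the detail this thread keeps circling around. A sketch, assuming only the public jax.default_backend() and jax.devices() calls, plus the assumption that jaxlib exposes its version as jaxlib.version.__version__; the function name describe_jax_setup is hypothetical, and the script returns {"installed": False} instead of failing when JAX is missing.

```python
"""Report the JAX/jaxlib versions, backend and devices seen by this interpreter."""

import importlib.util


def describe_jax_setup():
    """Return a dict describing JAX, or {'installed': False} if it is absent."""
    if importlib.util.find_spec("jax") is None:
        return {"installed": False}
    import jax  # imported lazily so the check also runs where JAX is missing
    import jaxlib.version  # assumption: jaxlib exposes __version__ here

    return {
        "installed": True,
        "jax_version": jax.__version__,
        "jaxlib_version": jaxlib.version.__version__,
        "backend": jax.default_backend(),  # "gpu" when the CUDA jaxlib works
        "devices": [str(device) for device in jax.devices()],
    }


if __name__ == "__main__":
    info = describe_jax_setup()
    print(info)
    if info["installed"] and info["backend"] != "gpu":
        print("Warning: JAX is not using the GPU; AlphaFold inference will be slow.")
```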

AlphaFold runs (about 5 min slower than the old version on the cluster for a 60-residue protein), but at runtime I get the warning: Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.

As described in issue #88, TensorFlow is installed as tensorflow-cpu on purpose, if I understand correctly.
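To confirm that the tensorflow-cpu build leaves the GPU to JAX, you can ask TensorFlow which GPUs it sees; with a CPU-only build the list should be empty. A minimal sketch using tf.config.list_physical_devices (the wrapper name tensorflow_gpu_devices is hypothetical):

```python
"""Check which GPUs TensorFlow can see (expected: none with tensorflow-cpu)."""

import importlib.util


def tensorflow_gpu_devices():
    """Return TensorFlow's visible GPU devices, or None if TF is not installed."""
    if importlib.util.find_spec("tensorflow") is None:
        return None
    import tensorflow as tf  # tensorflow-cpu is imported under the same name

    return tf.config.list_physical_devices("GPU")


if __name__ == "__main__":
    gpus = tensorflow_gpu_devices()
    if gpus is None:
        print("TensorFlow is not installed in this environment")
    elif gpus:
        print("TensorFlow sees GPUs:", gpus)
    else:
        print("TensorFlow sees no GPU, as expected for tensorflow-cpu")
```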

When I instead install 0.4.14 with

pip install --upgrade --no-cache-dir jax==0.4.14 jaxlib==0.4.14+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

I don't get the warning anymore, but AlphaFold fails at runtime after the "features" step with: Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
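One plausible cause worth ruling out (not confirmed in this thread): the environment file leaves cudnn unpinned, so conda may install a cuDNN older than the 8.6 that a cudnn86 wheel was built against, while the same cuDNN still satisfies the cudnn805 wheel. The sketch below compares the wheel's local tag with a cuDNN version string you look up yourself (for example from conda list cudnn). Assumptions, also noted in the code: the tag format cudaNN.cudnnXY..., one digit per version component, and that a compatible runtime cuDNN has the same major version and is at least as new as the build version. The function names and the version 8.2.1 in the usage lines are illustrative only.

```python
"""Compare the cuDNN version a jaxlib CUDA wheel was built for with an installed one.

Assumptions (not taken from AlphaFold or JAX code):
- the local tag looks like '+cuda11.cudnn86';
- each digit after 'cudnn' is one version component ('805' -> 8.0.5, '86' -> 8.6);
- a compatible runtime cuDNN has the same major version and is not older.
"""

import re


def wheel_requirements(jaxlib_version):
    """'0.4.14+cuda11.cudnn86' -> (11, (8, 6))."""
    match = re.search(r"\+cuda(\d+)\.cudnn(\d+)$", jaxlib_version)
    if match is None:
        raise ValueError(f"no CUDA/cuDNN tag in {jaxlib_version!r}")
    cuda_major = int(match.group(1))
    cudnn_built = tuple(int(digit) for digit in match.group(2))
    return cuda_major, cudnn_built


def cudnn_compatible(jaxlib_version, installed_cudnn):
    """True if an installed cuDNN such as '8.2.1' satisfies the wheel's tag."""
    _, built = wheel_requirements(jaxlib_version)
    installed = tuple(int(part) for part in installed_cudnn.split("."))
    # Tuple comparison: (8, 2, 1) < (8, 6) because 2 < 6.
    return installed[0] == built[0] and installed >= built


# Hypothetical installed cuDNN version, for illustration only.
for tag in ["0.3.25+cuda11.cudnn805", "0.4.14+cuda11.cudnn86"]:
    print(tag, cudnn_compatible(tag, "8.2.1"))
# 0.3.25+cuda11.cudnn805 True
# 0.4.14+cuda11.cudnn86 False
```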

Does anyone have an idea?

AssmannG, Feb 21 '24 10:02

I agree, this is weird. I'm running the Docker version myself, and inference took much longer than usual because the GPU wasn't detected for some reason. My CUDA version is 11.1.1, exactly as in the Dockerfile.

Have you made any progress on your side?

SkanderMarsit, Mar 08 '24 16:03

Hi, I did some more tests and tried different versions of various packages, and I actually got it running by using a more recent conda, which is odd as far as I understand. The central conda on our system was quite old (4.10.3), so I switched to a very recent one (24.X.X), and that fixed the problem.
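If you script the installation, you can make it refuse to run with a conda as old as the one that failed here. The thread does not establish why the newer conda helped or what the real minimum is, so the sketch below only compares against 4.10.3, the version reported as problematic; the constant KNOWN_BAD and the function names are hypothetical.

```python
"""Refuse to proceed with a conda as old as the 4.10.3 reported in this thread."""

import re
import shutil
import subprocess

# Version reported as failing in this thread; NOT a documented minimum.
KNOWN_BAD = (4, 10, 3)


def parse_conda_version(output):
    """Parse 'conda X.Y.Z' output, e.g. 'conda 4.10.3' -> (4, 10, 3)."""
    match = re.search(r"conda (\d+)\.(\d+)\.(\d+)", output)
    if match is None:
        raise ValueError(f"unrecognized conda version output: {output!r}")
    return tuple(int(part) for part in match.groups())


def active_conda_version():
    """Return the version of the conda on PATH, or None if it cannot be run."""
    exe = shutil.which("conda")
    if exe is None:
        return None
    result = subprocess.run([exe, "--version"], capture_output=True, text=True)
    if result.returncode != 0:
        return None
    # Read both streams in case the version is printed to stderr.
    return parse_conda_version(result.stdout + result.stderr)


if __name__ == "__main__":
    version = active_conda_version()
    if version is None:
        print("conda not found on PATH")
    elif version <= KNOWN_BAD:
        print(f"conda {'.'.join(map(str, version))} is as old as the failing setup")
    else:
        print(f"conda {'.'.join(map(str, version))} is newer than 4.10.3")
```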

I still installed JAX with

pip3 install --upgrade --no-cache-dir jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn805 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

and used this environment.yml file:

channels:
  - pytorch
  - conda-forge
  - defaults
  - anaconda
  - bioconda
dependencies:
  - python==3.10
  - pip
  - openmm==7.7.0
  - cudnn  # Change version if not compatible with current system
  - cudatoolkit
  - pdbfixer
  - hmmer==3.4
  - hhsuite==3.3.0
  - kalign2==2.04
  - pip:
    - absl-py==1.0.0
    - biopython==1.79
    - chex==0.0.7
    - dm-haiku==0.0.10
    - immutabledict==2.0.0
    - ml-collections==0.1.0
    - numpy==1.24.3
    - scipy==1.11.1
    - tensorflow-cpu==2.13.0
    - jax==0.4.14
    - pandas==2.0.3
    - dm-tree==0.1.8

I still get some warnings, and I want to try the very latest jax/jaxlib (0.4.23, I think), but haven't done so yet.

Do you get gpu when you run the following?

python -c "import jax; print(f'Jax backend: {jax.default_backend()}')"

I get Jax backend: gpu, and jax.devices() returns [gpu(id=0)].

AssmannG, Mar 15 '24 08:03