jax-windows-builder

Is it possible to build a wheel for cuDNN v8.8?

Open • sharpe5 opened this issue 1 year ago • 2 comments

I'm trying to find a way to install JAX v0.4.11 from here.

Q. Would it be possible to build a wheel for cuDNN v8.8?

The reason? Unfortunately, there are no Anaconda builds for cuDNN v8.6 or v8.9; the best one I could find was cuDNN v8.8, see: https://anaconda.org/conda-forge/cudnn/files

Appendix A

Here is how I installed JAX v0.3.25 on Windows + Anaconda. It is a completely self-contained method that does not rely on any external Windows installers from NVIDIA.

BTW, I could create a pull request with these extra docs if it would help others?

# Install Anaconda or Miniconda
conda create -n py310jax python=3.10 -y
conda activate py310jax
conda install -c conda-forge cudatoolkit=11.1 cudnn -y
# TensorFlow 2.10 was the last version to support CUDA+GPU on Windows.
pip install "tensorflow<2.11"
# Install jaxlib
#   - Download file "jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl" from "https://whls.blob.core.windows.net/unstable/index.html"
pip install jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl
# Install matching version of jax
pip install jax==0.3.25
# Now we can run JAX-based Python code on Windows.
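
One way to check that the install above actually picked up the GPU is a short Python script run inside the activated environment. This is only a sketch: the helper name `describe_jax_install` is made up for illustration, and the assumption that a working CUDA install reports its backend as "gpu" (or "cuda") may not hold for every jaxlib build.

```python
import importlib


def describe_jax_install():
    """Return a summary of the JAX install, or None if jax cannot be imported.

    Hypothetical helper for illustration; not part of jax or jax-windows-builder.
    """
    try:
        jax = importlib.import_module("jax")
        jaxlib = importlib.import_module("jaxlib")
    except ImportError:
        return None
    return {
        "jax": jax.__version__,
        "jaxlib": getattr(jaxlib, "__version__", "unknown"),
        # Name of the backend JAX selected by default, e.g. "cpu" or "gpu".
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


summary = describe_jax_install()
if summary is None:
    print("jax is not importable in this environment")
else:
    for key, value in summary.items():
        print(f"{key}: {value}")
    # Assumption: a working CUDA install reports "gpu" (or "cuda") here.
    if summary["backend"] not in ("gpu", "cuda"):
        print("warning: JAX fell back to a non-GPU backend")
```

If the backend prints as "cpu", the CUDA or cuDNN libraries were most likely not found at import time.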

sharpe5, Jul 03 '23 08:07

Building a wheel for every cuDNN version is unmaintainable. The matched version pairs are selected only because JAX officially has that combination in its releases. But you should be able to build jaxlib from source fairly easily by using https://github.com/cloudhan/jax-windows-builder/blob/overhaul/build-jaxlib.ps1.

All you need is MSYS2, Git, an MSVC 2019 installation, and Python with the required Python dependencies.

cloudhan, Jul 03 '23 09:07

Thanks, I like your comment:

The matched version pairs are selected only because JAX officially has that combination in its releases.

Knowing this, I am prepared to run everything in experimental mode and rely on unit tests to determine code validity. This seemed to work nicely:

  • Windows 10 x64
  • Python 3.10
  • cudatoolkit 11.8
  • cudnn 8.2.1
  • TensorFlow 2.10.1
  • jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl downloaded from this wheel repo.
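
The wheel filename itself records which CUDA and cuDNN versions the build targets, so it can be compared against the conda packages listed above. Below is a minimal sketch that pulls those tags out with a regular expression. The function name `parse_jaxlib_wheel` is hypothetical, and reading `cudnn82` as "built against cuDNN 8.2" is an assumption based on the versions in this thread, not a documented naming rule.

```python
import re

# Wheel filenames follow: name-version[+local]-python-abi-platform.whl
WHEEL_RE = re.compile(
    r"^(?P<name>[^-]+)-(?P<version>[^-+]+)(?:\+(?P<local>[^-]+))?"
    r"-(?P<python>[^-]+)-(?P<abi>[^-]+)-(?P<platform>[^-]+)\.whl$"
)


def parse_jaxlib_wheel(filename):
    """Split a jaxlib wheel filename into its tags (hypothetical helper)."""
    match = WHEEL_RE.match(filename)
    if match is None:
        raise ValueError(f"not a recognised wheel filename: {filename!r}")
    info = match.groupdict()
    # The local version label, e.g. "cuda11.cudnn82", carries the build tags.
    local = info.pop("local") or ""
    cuda = re.search(r"cuda(\d+)", local)
    cudnn = re.search(r"cudnn(\d+)", local)
    info["cuda"] = cuda.group(1) if cuda else None
    info["cudnn"] = cudnn.group(1) if cudnn else None
    return info


info = parse_jaxlib_wheel("jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl")
for key, value in info.items():
    print(f"{key}: {value}")
```

The `python` tag (`cp310`) also has to match the interpreter in the conda environment, which is why the environment above is created with `python=3.10`.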

Example install on Anaconda:

conda create -n py310jax python=3.10 -y
conda activate py310jax
conda install -c conda-forge cudatoolkit=11.8 cudnn -y
pip install "tensorflow<2.11"
pip install "jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl"
pip install jax==0.3.25
pip install tensorflow-probability==0.18
conda install -c nvidia cuda-nvcc -y
conda install -c conda-forge pandas tabulate pydantic pyarrow scikit-learn numpy numba -y

sharpe5, Jul 03 '23 12:07