jaxDecomp
Easier way to choose CUDA versions for end users
Currently, the CMake configuration sets the CUDA version to 12.2.
jaxDecomp (and cuDecomp) can also be compiled with CUDA 11.8, since there is no CUDA 12-specific code.
JAX 0.4.26 and above no longer support CUDA 11, but some machines do not have recent enough drivers, so some users have to use JAX 0.4.25.
I propose letting users choose which CUDA version to compile jaxDecomp with, like so:
By default (CUDA 12.2):
pip install jaxdecomp
or, to pick the version explicitly:
pip install jaxdecomp[cuda11]
pip install jaxdecomp[cuda12]
These extras would not download the NVIDIA wheels; we would still expect users to have the CUDA modules loaded.
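A minimal sketch of how the proposed extras could be chosen, based only on the constraint stated above (JAX 0.4.26 and later dropped CUDA 11). The helper names are hypothetical and the cuda11/cuda12 extras are the proposal, not an existing interface; the function just builds the pip command and rejects the one combination known not to work.

```python
import re
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse the leading major.minor[.patch] of a version string, e.g. '0.4.26.dev1'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


# From the thread: JAX 0.4.26 and above no longer supports CUDA 11.
LAST_JAX_WITH_CUDA11 = (0, 4, 25)


def install_command(jax_version: str, cuda_major: Optional[int] = None) -> str:
    """Build the pip command for the hypothetical cuda11/cuda12 extras."""
    if cuda_major is None:
        # Default build (CUDA 12.2 in the current CMake setup): no extra needed.
        return "pip install jaxdecomp"
    if cuda_major not in (11, 12):
        raise ValueError(f"unsupported CUDA major version: {cuda_major}")
    if cuda_major == 11 and parse_version(jax_version) > LAST_JAX_WITH_CUDA11:
        raise ValueError(
            f"JAX {jax_version} no longer supports CUDA 11; "
            "use CUDA 12 or JAX 0.4.25"
        )
    # Quote the requirement so shells such as zsh do not glob the brackets.
    return f'pip install "jaxdecomp[cuda{cuda_major}]"'
```

For example, a user stuck on JAX 0.4.25 would call install_command("0.4.25", 11) and get pip install "jaxdecomp[cuda11]", while asking for CUDA 11 with JAX 0.4.26 raises an error instead of producing a broken install.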
This would be really great: being able to install prebuilt versions. What I don't know is whether it's possible to do this in a way that incorporates NVHPC, for both technical and licensing reasons.
I'd be very interested in your thoughts on how to do it.
We don't need to ship the NVHPC compiler, since it is only needed at compile time, not at runtime.
All runtime requirements (CUDA, NCCL, cuFFT, etc.) are available as wheels installable via pip.
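For an all-pip route, those runtime libraries could be declared as extras. This is a sketch only, assuming NVIDIA's PyPI naming scheme (nvidia-<library>-cu11 / -cu12) and leaving out version pins; it is not jaxDecomp's actual packaging. Unlike the modules-based proposal, these extras would pull the NVIDIA wheels themselves.

```python
# Runtime libraries named in the thread; under NVIDIA's naming scheme the
# wheel for the CUDA runtime itself is "cuda-runtime".
RUNTIME_LIBS = ("cuda-runtime", "nccl", "cufft")


def nvidia_wheels(cuda_major: int) -> list:
    """Return unpinned NVIDIA runtime wheel names for one CUDA major version."""
    return [f"nvidia-{lib}-cu{cuda_major}" for lib in RUNTIME_LIBS]


# Could be passed to setuptools.setup(extras_require=EXTRAS_REQUIRE) so that
# pip install "jaxdecomp[cuda12]" also installs the matching runtime wheels.
EXTRAS_REQUIRE = {f"cuda{major}": nvidia_wheels(major) for major in (11, 12)}
```

Keeping one extra per CUDA major version mirrors the cuda11/cuda12 split proposed earlier, so the same install command works for both strategies.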
We would have to build distributable wheels and host them; they could then be installed via pip install -f <link>,
or maybe we push them to PyPI (not sure how that is done).
https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-f
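pip's -f/--find-links option accepts a local directory or a URL to an HTML page whose links point at archive files. A minimal sketch of generating such a page for a directory of built wheels; the directory layout and base URL are assumptions, not an existing jaxDecomp hosting setup.

```python
import html
from pathlib import Path
from urllib.parse import quote


def write_find_links_index(wheel_dir: str, base_url: str) -> Path:
    """Write wheel_dir/index.html with one link per .whl file, for pip -f."""
    directory = Path(wheel_dir)
    base = base_url.rstrip("/")
    links = []
    for wheel in sorted(directory.glob("*.whl")):
        # Percent-encode the file name ('+' in local versions becomes %2B),
        # then HTML-escape the full URL for the attribute value.
        href = html.escape(f"{base}/{quote(wheel.name)}", quote=True)
        links.append(f'<a href="{href}">{html.escape(wheel.name)}</a><br>')
    page = "<!DOCTYPE html>\n<html><body>\n" + "\n".join(links) + "\n</body></html>\n"
    index = directory / "index.html"
    index.write_text(page, encoding="utf-8")
    return index
```

After uploading the wheels and index.html to, say, https://example.org/wheels/ (a placeholder), users would run pip install jaxdecomp -f https://example.org/wheels/index.html.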
Great, great, great! As with JAX, we probably ultimately want to offer two install strategies: one with a local install, and one that is all pip :-)
Started working on this in https://github.com/DifferentiableUniverseInitiative/jaxDecomp/tree/push-to-pypi
Closing, as it is now dead simple to install jaxDecomp.