
Easier way to choose CUDA versions for end users

Open ASKabalan opened this issue 9 months ago • 4 comments

Currently, the CMake configuration sets the CUDA version to 12.2.

jaxDecomp (and cuDecomp) can also be compiled with CUDA 11.8, since neither contains CUDA-12-specific code.
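One way to let users pick the toolkit at build time is to read an environment variable and forward it to CMake. This is a minimal sketch, assuming a hypothetical `JAXDECOMP_CUDA_VERSION` environment variable and a hypothetical `CUDA_VERSION` CMake cache entry; neither is part of jaxDecomp's actual build.

```python
import os

# Versions reported in this issue as working; 12.2 is the current CMake default.
SUPPORTED_CUDA_VERSIONS = ("11.8", "12.2")
DEFAULT_CUDA_VERSION = "12.2"


def cmake_cuda_args(environ=None):
    """Return CMake arguments selecting the CUDA version.

    JAXDECOMP_CUDA_VERSION and -DCUDA_VERSION are hypothetical names
    used only for illustration.
    """
    environ = os.environ if environ is None else environ
    version = environ.get("JAXDECOMP_CUDA_VERSION", DEFAULT_CUDA_VERSION)
    if version not in SUPPORTED_CUDA_VERSIONS:
        raise ValueError(
            f"Unsupported CUDA version {version!r}; "
            f"expected one of {SUPPORTED_CUDA_VERSIONS}"
        )
    return [f"-DCUDA_VERSION={version}"]
```

With nothing set, this keeps today's 12.2 default; exporting `JAXDECOMP_CUDA_VERSION=11.8` before building would switch to 11.8, and anything else fails early instead of deep inside the CMake run.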

JAX 0.4.26 and above no longer support CUDA 11, but some machines do not have recent enough drivers, so some users have to stay on JAX 0.4.25.
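The cutoff described above can be encoded as a small compatibility check. This sketch only captures the rule stated in this issue (JAX 0.4.26 is the first release without CUDA 11 support); it does not model any other JAX/CUDA constraints, and it assumes plain release strings such as `0.4.25` (no `rc` or `dev` suffixes).

```python
# Per this issue: JAX 0.4.26 is the first release that dropped CUDA 11.
FIRST_JAX_WITHOUT_CUDA11 = (0, 4, 26)


def parse_version(version):
    """Turn a plain release string like '0.4.25' into (0, 4, 25)."""
    return tuple(int(part) for part in version.split(".")[:3])


def jax_supports_cuda(jax_version, cuda_major):
    """Return True if this JAX release can be used with the CUDA major version.

    Only the CUDA 11 cutoff from the issue is modelled; CUDA 12 is
    assumed to be supported by every release passed in.
    """
    if cuda_major not in (11, 12):
        raise ValueError(f"unsupported CUDA major version: {cuda_major}")
    if cuda_major == 11:
        return parse_version(jax_version) < FIRST_JAX_WITHOUT_CUDA11
    return True
```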

I propose letting users choose which CUDA version to compile jaxDecomp with, like so:

By default (CUDA 12.2):

pip install jaxdecomp 

or

pip install jaxdecomp[cuda11]
pip install jaxdecomp[cuda12]

But obviously we would not download the NVIDIA wheels; we still expect the user to have the CUDA modules loaded.
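Because the user is expected to have the CUDA modules loaded, a build script could check that the `nvcc` on `PATH` matches the requested extra before compiling. A minimal sketch, assuming the loaded module puts `nvcc` on `PATH` and that `nvcc --version` prints its usual `release X.Y` line; the helper names and the `cuda11`/`cuda12` extra mapping are hypothetical, taken from the proposal above.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(output):
    """Extract 'X.Y' from `nvcc --version` output, or None if not found."""
    match = re.search(r"release (\d+\.\d+)", output)
    return match.group(1) if match else None


def local_cuda_release():
    """Return the release of the nvcc found on PATH, or None if there is none."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None  # no CUDA module loaded, or toolkit not on PATH
    result = subprocess.run(
        [nvcc, "--version"], capture_output=True, text=True, check=False
    )
    return parse_nvcc_release(result.stdout)


def matches_extra(release, extra):
    """Check that a release like '11.8' matches an extra like 'cuda11'."""
    return release.split(".")[0] == extra.removeprefix("cuda")
```

A build step could then abort with a clear message when, say, `jaxdecomp[cuda11]` is requested but `local_cuda_release()` reports `12.2`.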

ASKabalan, Apr 29 '24 07:04

It would be really great to be able to install prebuilt versions. What I don't know is whether that can be done in a way that incorporates NVHPC, for both technical and licensing reasons.

I'm very interested if you have thoughts on how to do it.

EiffL, Apr 29 '24 18:04

We don't need to incorporate the NVHPC compiler, since it is only used at compile time, not at runtime.

All runtime requirements (CUDA, NCCL, cuFFT, etc.) exist as wheels installable via pip.
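To see which of those runtime wheels are already present in an environment, the standard library's `importlib.metadata` is enough. The wheel names below are real NVIDIA distributions on PyPI, but the list itself is an illustrative assumption, not jaxDecomp's declared dependency set.

```python
from importlib import metadata


def runtime_wheels(cuda_major):
    """Illustrative NVIDIA runtime wheels for a CUDA major version (assumed list)."""
    if cuda_major not in (11, 12):
        raise ValueError(f"unsupported CUDA major version: {cuda_major}")
    suffix = f"cu{cuda_major}"
    return [f"nvidia-{lib}-{suffix}" for lib in ("cuda-runtime", "cufft", "nccl")]


def missing_runtime_wheels(cuda_major):
    """Return the wheels from runtime_wheels() that are not installed."""
    missing = []
    for name in runtime_wheels(cuda_major):
        try:
            metadata.version(name)  # raises if the distribution is absent
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing
```

An empty result from `missing_runtime_wheels(12)` would mean a pure-pip setup is possible; otherwise the user falls back on locally loaded modules.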

We would have to build distributable wheels and host them; they could then be installed via pip install -f <link>, or maybe pushed to PyPI (not sure how that is done).

https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-f
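The `-f`/`--find-links` route can be scripted by building the pip command explicitly. A minimal sketch: the `cuda11`/`cuda12` extras come from the proposal above, and any wheel-index URL passed in is a hypothetical placeholder, since no hosting location is given here.

```python
import sys


def pip_install_command(extra=None, find_links=None):
    """Build a pip command installing jaxdecomp, optionally from extra wheel links."""
    requirement = f"jaxdecomp[{extra}]" if extra else "jaxdecomp"
    cmd = [sys.executable, "-m", "pip", "install"]
    if find_links:
        # --find-links (-f): also look for archives at this URL or local directory
        cmd += ["--find-links", find_links]
    cmd.append(requirement)
    return cmd
```

The resulting list can be passed to `subprocess.run(cmd, check=True)`; using `sys.executable -m pip` ensures the wheels land in the same interpreter that will import them.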

ASKabalan, Apr 29 '24 18:04

Great, great, great. Like for JAX, we probably ultimately want to offer two install strategies: one with a local install, one with everything from pip :-)

EiffL, Apr 29 '24 19:04

Started working on this in https://github.com/DifferentiableUniverseInitiative/jaxDecomp/tree/push-to-pypi

ASKabalan, Jun 18 '24 20:06

Closing, as it is now dead simple to install jaxDecomp.

ASKabalan, Oct 30 '24 15:10