
Parametrize build system on CUDA major version

Open: olupton opened this issue 7 months ago • 2 comments

olupton (May 23 '25 09:05)

Try rebasing on head as well, please. Sorry for the slow review.

hawkinsp (Jun 04 '25 17:06)

Rebased again to resolve new conflicts.

olupton (Jun 16 '25 14:06)

There was a recent change that affects this implementation too:

https://github.com/jax-ml/jax/blob/09d903fc9a60726d5b4e57769a6bdbd424d44b08/jax_plugins/cuda/__init__.py#L123-L141

I think you need something similar, using try/except, to load either the CUDA 12 or the CUDA 13 wheels.
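The try/except fallback suggested above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code in jax_plugins/cuda/__init__.py: the helper load_first_available, the constant CANDIDATE_MODULES, the module names in it, and the preference order (CUDA 13 before CUDA 12) are all hypothetical, chosen only to show the pattern of trying one CUDA major version's package and falling back to the other.

```python
import importlib
from types import ModuleType
from typing import Optional, Sequence

# Hypothetical module names and preference order; the real JAX plugin
# packages and the order in which they are tried may differ.
CANDIDATE_MODULES = (
    "jax_cuda13_plugin",
    "jax_cuda12_plugin",
)


def load_first_available(names: Sequence[str]) -> Optional[ModuleType]:
    """Return the first module in `names` that imports successfully, else None."""
    for name in names:
        try:
            return importlib.import_module(name)
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so this covers
            # both "not installed" and "installed but failed to import".
            # Either way, fall through to the next CUDA major version.
            continue
    return None
```

For example, plugin = load_first_available(CANDIDATE_MODULES) returns None when neither package can be imported, so the caller can skip registering the CUDA backend. Catching ImportError rather than only ModuleNotFoundError also falls through when a package is present but fails to load (for instance, a compiled extension whose shared libraries are missing); whether that is the desired behavior is a design choice.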

Thanks. I think this one can be addressed in a follow-up if that's OK.

olupton (Jun 25 '25 10:06)

@olupton I reran CI with the latest changes; it looks like there is still an issue in the CUDA test and the linter.

MichaelHudgins (Jul 17 '25 20:07)


Sorry, I forgot to mention two more targets that need their deps updated: jax_cuda_plugin_wheel_size_test and jax_cuda_pjrt_wheel_size_test.

ybaturina (Jul 17 '25 20:07)

Done, sorry about that. I also tried to fix the mypy linter failure.

olupton (Jul 18 '25 16:07)

I'm getting one additional linter error internally on jax.bzl. I'm going to try to patch it myself.

MichaelHudgins (Jul 18 '25 17:07)