Parametrize build system on CUDA major version
Try rebasing on head as well, please. Sorry for the slow review.
Rebased again to resolve new conflicts.
There was a recent change that has an impact on this implementation too:
https://github.com/jax-ml/jax/blob/09d903fc9a60726d5b4e57769a6bdbd424d44b08/jax_plugins/cuda/init.py#L123-L141
I think you need something similar here: a try-except that loads either the CUDA 12 or the CUDA 13 wheels, whichever is installed.
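To make the suggestion concrete, here is a minimal sketch of that kind of try-except fallback. It is not the code at the link above: the helper `import_first_available` is hypothetical, and the package names passed to it are only assumed for illustration.

```python
import importlib


def import_first_available(candidates):
    """Return (name, module) for the first importable candidate.

    Returns (None, None) if none of the candidates can be imported.
    """
    for name in candidates:
        try:
            return name, importlib.import_module(name)
        except ImportError:
            # This wheel is not installed; fall through to the next one.
            continue
    return None, None


if __name__ == "__main__":
    # Assumed package names, newest CUDA major version first.
    name, plugin = import_first_available(
        ["jax_cuda13_plugin", "jax_cuda12_plugin"]
    )
    print("loaded:", name)
```

Ordering the candidates newest-first means the CUDA 13 wheel wins when both are installed; whether that is the right preference here is a design choice for the PR.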
Thanks. I think this one can be addressed in a follow-up if that's OK.
@olupton I reran CI with the latest changes. It looks like there is still an issue in the CUDA test and the linter.
Sorry, I forgot to mention two more targets that need their deps updated: jax_cuda_plugin_wheel_size_test and jax_cuda_pjrt_wheel_size_test.
Done, sorry about that. Tried to fix the mypy linter failure too.
I'm getting one additional linter error internally on jax.bzl. I'm going to try to patch it myself.