Split XLA into a separate shared library
Currently jaxlib consists of essentially three things:
- XLA and its CPU and GPU backends.
- A little C++ runtime around XLA.
- Python bindings around both.
We build all of these together into a single shared library, xla_extension.so.
At the moment we support 3 Python versions {3.6, 3.7, 3.8} and 5 CUDA variants {none, 9.2, 10.0, 10.1, 10.2}. Because we link everything into a single shared object, we must make 3 × 5 = 15 release builds.
If we were to split the Python-binding parts from the compiler/runtime parts, we would instead only need to perform 3 + 5 = 8 builds. One way to do this would be to split the shared library in two:
- a thin Python binding layer (small and cheap to build), rebuilt for each Python version, and
- the bulk of the C++ code (XLA and the runtime), rebuilt for each CUDA variant.
A sketch of what the thin binding layer might look like follows the note below.
Note: both parts would still need to be built and distributed together, because we don't have a stable API or ABI between the two.
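To make the split concrete, here is a minimal sketch of the thin binding layer, assuming a hypothetical runtime library that exports C functions named `JaxRuntimeVersion` and `JaxCompileComputation`; these names are illustrative, not the real jaxlib API:

```cpp
// binding_layer.cc -- hypothetical thin Python-binding shared object.
// It contains no compiler code itself; it only forwards to symbols
// exported by the larger compiler/runtime library.
#include <cstdint>
#include <string>

#include <pybind11/pybind11.h>

// Declarations for symbols exported by the runtime library.
// (Illustrative names, not the real jaxlib API.)
extern "C" {
const char* JaxRuntimeVersion();
void* JaxCompileComputation(const char* hlo_text);
}

PYBIND11_MODULE(xla_extension, m) {
  // Only this small module must be rebuilt per Python version; the
  // runtime library it links against is built once per CUDA variant.
  m.def("runtime_version",
        []() { return std::string(JaxRuntimeVersion()); });
  m.def("compile", [](const std::string& hlo) {
    // Return an opaque handle; a real binding would wrap this properly.
    return reinterpret_cast<std::uintptr_t>(
        JaxCompileComputation(hlo.c_str()));
  });
}
```

Only this module varies with the Python version, which is what brings the build count down from the 3 × 5 product to the 3 + 5 sum.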
Possible refinement: distribute the Python-specific and CPU-only parts as jaxlib, and distribute a CUDA "plugin" separately (e.g., jaxlib-cuda92) that contains only the GPU plugin as a shared library.
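Such a plugin could be discovered and loaded at runtime rather than linked in. Here is a minimal sketch of how the core library might load an optional GPU plugin with dlopen; the library name, exported symbol, and its signature are all assumptions for illustration:

```cpp
// plugin_loader.cc -- hypothetical loader for an optional GPU plugin.
#include <dlfcn.h>

#include <cstdio>

// Signature the plugin is assumed to export (illustrative only).
using RegisterBackendFn = int (*)();

bool LoadGpuPlugin(const char* path) {
  // RTLD_LOCAL keeps the plugin's symbols from leaking into the
  // process-wide namespace and colliding with the core library's.
  void* handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    // A missing plugin is not an error: the user simply gets CPU only.
    std::fprintf(stderr, "No GPU plugin: %s\n", dlerror());
    return false;
  }
  auto register_backend = reinterpret_cast<RegisterBackendFn>(
      dlsym(handle, "JaxRegisterGpuBackend"));
  if (register_backend == nullptr) {
    std::fprintf(stderr, "Bad GPU plugin: %s\n", dlerror());
    dlclose(handle);
    return false;
  }
  return register_backend() == 0;
}

int main() {
  // Path is illustrative, e.g. shipped by a jaxlib-cuda92 wheel.
  LoadGpuPlugin("libjaxlib_cuda92_plugin.so");
}
```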
I created an issue for this on the TensorFlow GitHub.
FWIW, we have pre-built XLA packages here: https://github.com/elixir-nx/xla
This solution is pretty specific to our projects and so might not be useful for everyone trying to use XLA, but it ships enough to at least experiment with and build against XLA outside of TensorFlow.
I'm going to declare this fixed, because we've now started shipping our CUDA support as a PJRT plugin (jax-cuda12-pjrt). This eliminated the worst part of the Python × CUDA build "product", because we now build CUDA support exactly once, with no Python version dependence. We may do the same for CPU support at some point.
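For context on why this removes the Python-version dependence: a PJRT plugin is a shared library whose boundary is a plain C ABI, conventionally a single exported entry point, GetPjrtApi, that returns a table of C function pointers. A sketch of loading such a plugin, treating the API table as opaque (the real layout lives in the PJRT C API headers, and the plugin path here is illustrative):

```cpp
// pjrt_plugin_load.cc -- sketch of loading a PJRT plugin library.
#include <dlfcn.h>

#include <cstdio>

struct PJRT_Api;  // Opaque in this sketch; defined by the PJRT C API.
using GetPjrtApiFn = const PJRT_Api* (*)();

const PJRT_Api* LoadPjrtPlugin(const char* path) {
  void* handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return nullptr;
  }
  auto get_api =
      reinterpret_cast<GetPjrtApiFn>(dlsym(handle, "GetPjrtApi"));
  if (get_api == nullptr) {
    std::fprintf(stderr, "no GetPjrtApi: %s\n", dlerror());
    dlclose(handle);
    return nullptr;
  }
  // Because the boundary is a C ABI, the plugin is built once and
  // works regardless of the caller's Python version.
  return get_api();
}

int main() {
  // Illustrative path; the real one is shipped inside the wheel.
  LoadPjrtPlugin("xla_cuda_plugin.so");
}
```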