XLA.jl
Automatically install XLA GPU
Installing jaxlib with GPU support by hand is fiddly:
```shell
# install jaxlib
PYTHON_VERSION=cp37  # alternatives: cp36, cp37, cp38
CUDA_VERSION=cuda92  # alternatives: cuda92, cuda100, cuda101, cuda102
PLATFORM=linux_x86_64  # alternatives: linux_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade "$BASE_URL/$CUDA_VERSION/jaxlib-0.1.46-$PYTHON_VERSION-none-$PLATFORM.whl"
pip install --upgrade jax  # install jax
```
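The manual steps amount to filling three slots in a URL template. A minimal sketch of that, assuming the wheel naming scheme stays `jaxlib-<version>-<python tag>-none-<platform>.whl` as in the commands above and that the only accepted values are the ones listed in their comments; the function name `jaxlib_wheel_url` is hypothetical.

```shell
#!/bin/sh
# Hypothetical helper: compose the jaxlib GPU wheel URL used by the manual
# install. Assumes the naming scheme
#   $BASE_URL/<cuda tag>/jaxlib-<version>-<python tag>-none-<platform>.whl
BASE_URL='https://storage.googleapis.com/jax-releases'

# Usage: jaxlib_wheel_url PYTHON_TAG CUDA_TAG PLATFORM JAXLIB_VERSION
# Prints the URL on stdout. If a tag is not one of the values listed in the
# manual instructions, prints an error to stderr and returns 1.
jaxlib_wheel_url() {
    py=$1 cuda=$2 platform=$3 version=$4
    case "$py" in
        cp36|cp37|cp38) ;;
        *) echo "unsupported Python tag: $py" >&2; return 1 ;;
    esac
    case "$cuda" in
        cuda92|cuda100|cuda101|cuda102) ;;
        *) echo "unsupported CUDA tag: $cuda" >&2; return 1 ;;
    esac
    case "$platform" in
        linux_x86_64) ;;
        *) echo "unsupported platform: $platform" >&2; return 1 ;;
    esac
    printf '%s/%s/jaxlib-%s-%s-none-%s.whl\n' \
        "$BASE_URL" "$cuda" "$version" "$py" "$platform"
}
```

Used in place of the hand-written URL: `url=$(jaxlib_wheel_url cp37 cuda101 linux_x86_64 0.1.46) && pip install --upgrade "$url"`. Chaining with `&&` keeps `pip` from running with an empty URL when a tag is rejected.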
All of these values (the Python ABI tag, the CUDA version and the platform) can be determined at build time, so the build step should detect them and install jaxlib automatically.
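A minimal sketch of the detection step, written in shell to match the commands above. It assumes `python3` on `PATH` is the interpreter jax should go into, that `nvcc` is on `PATH` and prints a line of the form `Cuda compilation tools, release 10.1, V10.1.243`, and that `uname` reports the OS and architecture; all function names are hypothetical.

```shell
#!/bin/sh
# Hypothetical sketch: derive the three tags the manual install asks for.

# Pure conversions (no system probing), e.g. "3.7" -> "cp37"
python_tag_from_version() {
    printf 'cp%s\n' "$(printf '%s' "$1" | tr -d .)"
}

# e.g. "10.1" -> "cuda101", "9.2" -> "cuda92"
cuda_tag_from_release() {
    printf 'cuda%s\n' "$(printf '%s' "$1" | tr -d .)"
}

# e.g. "Linux" "x86_64" -> "linux_x86_64"
platform_tag_from_uname() {
    printf '%s_%s\n' "$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]')" "$2"
}

# Probes: tag of the python3 on PATH (assumed to be the target interpreter).
detect_python_tag() {
    ver=$(python3 -c 'import sys; print("%d.%d" % sys.version_info[:2])') || return 1
    python_tag_from_version "$ver"
}

# Tag from `nvcc --version`, which is assumed to print e.g.
#   Cuda compilation tools, release 10.1, V10.1.243
detect_cuda_tag() {
    rel=$(nvcc --version 2>/dev/null |
        sed -n 's/.*release \([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p')
    [ -n "$rel" ] || { echo "could not determine CUDA release from nvcc" >&2; return 1; }
    cuda_tag_from_release "$rel"
}

detect_platform_tag() {
    platform_tag_from_uname "$(uname -s)" "$(uname -m)"
}

# Sets the same variables the manual instructions set by hand.
detect_jaxlib_tags() {
    PYTHON_VERSION=$(detect_python_tag) || return 1
    CUDA_VERSION=$(detect_cuda_tag) || return 1
    PLATFORM=$(detect_platform_tag) || return 1
}
```

After `detect_jaxlib_tags` succeeds, the existing `pip install` line can run unchanged. Keeping the string conversions separate from the probes means the conversions can be checked on machines without a GPU or `nvcc`.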