JAX is not compatible with the latest PyArrow
Description
I am not able to create a conda environment with jax and pyarrow>=14.0. This appears to be caused by the pinned libgrpc version (>=1.58.1,<1.59.0a0).
System info (python version, jaxlib version, accelerator, etc.)
Here is a simple example to reproduce:
mamba create -n lala python=3.11 jax pyarrow=15 --dry-run
Hi, thanks for the report! JAX itself doesn't pin any libgrpc version. The pin is most likely coming from the jaxlib conda-forge package: https://github.com/conda-forge/jaxlib-feedstock. I'm not sure why such a specific version is pinned; the best place to raise this issue would probably be there.
In the meantime, you may be able to work around this by installing pyarrow via mamba and jax via pip.
As Jake says, this is a property of the (community-supported) conda-forge package. You can probably just install JAX's pip package inside your conda environment. The pip package vendors dependencies such as grpc internally, so they don't show up as package dependencies.
Since we don't ship the conda-forge packages, there's not much else we can do on our end!
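For reference, the suggested workaround (conda/mamba for pyarrow, pip for jax) might look roughly like the sketch below. It reuses the environment name `lala` from the reproduction command. Two things here are assumptions and not taken from the thread: that pyarrow comes from the `conda-forge` channel, and that your shell has already been set up with `conda init`/`mamba init` so that `conda activate` works. Installing the conda packages first and only then running pip is the usual ordering when mixing the two.

```shell
# Create the environment with pyarrow only; jax/jaxlib are deliberately NOT
# requested from conda here, so the conda-forge libgrpc pin never comes into play.
# (The -c conda-forge channel is an assumption; adjust to your channel setup.)
mamba create -n lala -c conda-forge python=3.11 pyarrow=15

# Activate the environment (assumes `conda init` / `mamba init` was run for this shell).
conda activate lala

# Install JAX from PyPI inside the conda environment; the pip wheels vendor
# native dependencies such as grpc, so they do not conflict with conda's libgrpc.
python -m pip install -U jax

# Sanity check: both packages import side by side and report their versions.
python -c "import jax, pyarrow; print(jax.__version__, pyarrow.__version__)"
```

If the last command prints two version strings without an import error, the two packages are coexisting in the same environment.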
It looks like this has been resolved in conda-forge. Thanks!