dalle-mini
Incompatible jax or jaxlib version in tools/inference/inference_pipeline.ipynb
In the third code block of the inference pipeline notebook, I get this error:
RuntimeError                              Traceback (most recent call last)
[... 4 frames ...]
/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version)
     67            f'incompatible with jax version {jax_version}. Please '
     68            'update your jax and/or jaxlib packages.')
---> 69     raise RuntimeError(msg)
     70 
     71   return _jaxlib_version

RuntimeError: jaxlib version 0.4.7 is newer than and incompatible with jax version 0.3.25. Please update your jax and/or jaxlib packages.
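One way to see which side of the mismatch you are on, without importing jax (the import is what raises the error above), is to read the installed versions from package metadata. This is a minimal sketch assuming Python 3.8+ for `importlib.metadata`; the helper names are hypothetical, and the comparison only mirrors the rule stated in the error message (jaxlib must not be newer than jax), not jax's complete compatibility check.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def as_tuple(v: str) -> tuple:
    """Parse the leading MAJOR.MINOR.PATCH of a version string, e.g. '0.3.25+cuda11' -> (0, 3, 25)."""
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(x) for x in m.groups())


def jaxlib_too_new(jax_version: str, jaxlib_version: str) -> bool:
    """True when jaxlib is newer than jax, the case rejected in the traceback above."""
    return as_tuple(jaxlib_version) > as_tuple(jax_version)


def installed(pkg: str):
    """Installed version of pkg from its metadata (the package is not imported), or None."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_v, jaxlib_v = installed("jax"), installed("jaxlib")
    print(f"jax={jax_v} jaxlib={jaxlib_v}")
    if jax_v and jaxlib_v and jaxlib_too_new(jax_v, jaxlib_v):
        print(f"jaxlib {jaxlib_v} is newer than jax {jax_v}; pin jaxlib=={jax_v}")
```

With the versions from this report, `jaxlib_too_new("0.3.25", "0.4.7")` is `True`.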
I got the following error with jax.
AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html
will solve your problem. It's basically a version mismatch between jax and jaxlib.
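Since the traceback complains about jax and jaxlib disagreeing, a variant of the command above pins both packages to the same release in one step. This is a sketch, assuming you want the 0.3.25 pair named in this thread; the `pin_command` helper and the `--run` flag are hypothetical, and the find-links URL is the one from the comment above. Running pip through `sys.executable` installs into the same interpreter the notebook is using.

```python
import shlex
import subprocess
import sys

# Find-links index taken from the suggested fix above.
JAX_RELEASES = "https://storage.googleapis.com/jax-releases/jax_releases.html"


def pin_command(jax_version: str) -> list:
    """Build a pip command that pins jax and jaxlib to the same version."""
    return [
        sys.executable, "-m", "pip", "install",
        f"jax=={jax_version}", f"jaxlib=={jax_version}",
        "-f", JAX_RELEASES,
    ]


if __name__ == "__main__":
    cmd = pin_command("0.3.25")
    print(shlex.join(cmd))
    # Only install when explicitly asked, so a dry run just prints the command.
    if "--run" in sys.argv:
        subprocess.run(cmd, check=True)
```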
I had the same issue. I tried installing jaxlib 0.3.25, but then a new issue showed up: orbax 0.1.7 and chex 0.1.7 require jax>=0.4.6.
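The new conflict follows directly from the numbers in this comment: pinning jax to 0.3.25 falls below the jax>=0.4.6 minimum that orbax 0.1.7 and chex 0.1.7 declare. A small sketch of that comparison, using only the requirements reported here; `blocked_by` and `MIN_JAX` are hypothetical names, and the parser handles plain `MAJOR.MINOR.PATCH` minimums only, not full pip specifiers.

```python
import re


def as_tuple(v: str) -> tuple:
    """Parse the leading MAJOR.MINOR.PATCH of a version string into a comparable tuple."""
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(x) for x in m.groups())


# Minimum jax versions as reported in the comment above.
MIN_JAX = {"orbax 0.1.7": "0.4.6", "chex 0.1.7": "0.4.6"}


def blocked_by(pinned_jax: str, min_jax: dict) -> list:
    """Packages whose minimum jax requirement is not met by the pinned jax version."""
    return [pkg for pkg, need in min_jax.items() if as_tuple(pinned_jax) < as_tuple(need)]


print(blocked_by("0.3.25", MIN_JAX))  # ['orbax 0.1.7', 'chex 0.1.7']
```

Keeping jax at 0.3.25 therefore means those two packages would also need older releases that accept it.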
@saeyrjac565: It looks like orbax and chex need a newer jax. Do you need them for running inference with dalle-mini? As I understand it, both jax and jaxlib should be at 0.3.25. For me, jax was already at 0.3.25 and only jaxlib was still at 0.4.7, which I fixed with the command from the comment above.
Also, I only have a single GPU with 12 GB of memory, so I couldn't use the GPU for inference and ran the pure CPU version instead.
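For the pure-CPU route, a minimal sketch, assuming a jax release that reads the `JAX_PLATFORMS` environment variable to choose its backend. The variable has to be set before jax is first imported in the process, so in a notebook it belongs above any cell that imports jax.

```python
import importlib.util
import os

# Restrict JAX to the CPU backend (assumes this jax release honors JAX_PLATFORMS).
# setdefault leaves any value that is already set untouched.
os.environ.setdefault("JAX_PLATFORMS", "cpu")

# Only attempt the import if jax is installed at all.
if importlib.util.find_spec("jax") is not None:
    import jax

    # If the variable was honored, this lists only CPU devices.
    print(jax.devices())
```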
Maybe orbax and chex just need to be pinned to older versions as well?
I believe @fang2020shu has resolved this issue on their fork! https://github.com/fang2020shu/dalle-mini
There was an issue with orbax, so I pinned it. This should resolve the problem.