Jax dependency issue
Jax dropped linear_util as of v0.4.25 (the latest release), so importing haiku under the latest jax crashes. Since ColabFold's pyproject.toml accepts any jax version from 0.4.20 upward, this breaks several ColabFold methods.
Please let me know if this makes sense or if you need more information.
I'm not sure how we can fix this. As far as I understand, we should pin only the jax version and not jaxlib; however, pinning jax alone doesn't prevent pip from installing a newer jaxlib.
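To illustrate why a lower-bound-only constraint is fragile, here is a minimal plain-Python sketch (no packaging dependency) that checks a version string against simple ">=" and "<" bounds. The upper bound 0.4.25 is a hypothetical example, not the constraint ColabFold actually adopted; it only shows that ">=0.4.20" on its own also admits the release without linear_util.

```python
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse a dotted release like '0.4.25' into (0, 4, 25).

    Assumption: numeric release segments only; pre-release tags such as
    'rc1' are not handled. Short versions are padded, so '0.4' == '0.4.0'.
    """
    parts = [int(part) for part in version.split(".")]
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def satisfies(version: str, lower: str, upper: Optional[str] = None) -> bool:
    """Return True if lower <= version and, when given, version < upper."""
    v = parse_version(version)
    if v < parse_version(lower):
        return False
    if upper is not None and v >= parse_version(upper):
        return False
    return True


# A lower bound only (">=0.4.20") accepts the release without linear_util.
print(satisfies("0.4.25", lower="0.4.20"))                  # True
# A hypothetical upper bound ("<0.4.25") would have excluded it.
print(satisfies("0.4.25", lower="0.4.20", upper="0.4.25"))  # False
```

The trade-off is the usual one: an upper bound avoids surprise breakage but must be bumped by hand for every new compatible release.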
Installing through poetry instead of pip should solve most issues, as it will install the versions specified in the lock file.
~I guess we have to wait for this to be merged:~ ~https://github.com/google-deepmind/dm-haiku/pull/739~
Never mind, this was already merged. However, we pin dm-haiku to 0.0.10, since I had some problem with 0.0.11, though I don't remember what it was.
I pinned dm-haiku to 0.0.12; this should hopefully avoid any issues once Google upgrades jax within Colab.
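One way to confirm that the pin actually took effect in a given environment (for example, a Colab runtime) is to print the installed distribution versions. This is a minimal standard-library sketch; it only reports what is installed and does not enforce any pin.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names as published on PyPI.
PACKAGES = ("jax", "jaxlib", "dm-haiku")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```

If dm-haiku reports a version other than 0.0.12 after installing ColabFold, the environment did not pick up the updated dependency.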
Thanks!
I think we can update jax to the latest version and change the code to use linear_util from the jax.extend module (see here). I can make a PR for this, and maybe fix the jax installs on Colab as well.
dm-haiku 0.0.12 already does this. It should be fine now that the dependency is updated.
I think this is a related (or the same) error. I returned to a ColabFold v1.5.5: AlphaFold2 w/ MMseqs2 BATCH session that was working yesterday, and now I get the error below. Is there a fix for this? I might just not understand what pinning dm-haiku 0.0.12 means.
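The migration described above can be sketched as a small compatibility import that prefers the jax.extend location and falls back to the legacy path. This is not the code used in dm-haiku or ColabFold; the helper name import_linear_util is hypothetical, and whether jax.extend.linear_util exists depends on the installed jax version.

```python
import importlib
from types import ModuleType
from typing import Optional


def import_linear_util() -> Optional[ModuleType]:
    """Hypothetical helper: return jax's linear_util from wherever it lives.

    Tries the newer jax.extend location first, then the legacy
    jax.linear_util path removed in newer jax releases. Returns None if
    neither is importable (for example, when jax is not installed).
    """
    for module_name in ("jax.extend.linear_util", "jax.linear_util"):
        try:
            return importlib.import_module(module_name)
        except ImportError:
            continue
    return None


lu = import_linear_util()
print("linear_util found at:", lu.__name__ if lu is not None else None)
```

Only ImportError is caught, so an environment-level failure such as a jax/jaxlib version mismatch still surfaces instead of being silently hidden.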
RuntimeError                              Traceback (most recent call last)
in <cell line: 5>()
      3 import sys
      4
----> 5 from colabfold.batch import get_queries, run
      6 from colabfold.download import default_data_dir
      7 from colabfold.utils import setup_logging

7 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version)
     62     msg = (f'jaxlib is version {jaxlib_version}, but this version '
     63            f'of jax requires version >= {minimum_jaxlib_version}.')
---> 64     raise RuntimeError(msg)
     65
     66   if _jaxlib_version > _jax_version:

RuntimeError: jaxlib is version 0.3.25, but this version of jax requires version >= 0.4.20.
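The traceback shows the two conditions jax checks on import: jaxlib must be at least the minimum version this jax requires, and jaxlib must not be newer than jax. The following is a simplified re-implementation of just those two conditions for illustration, not jax's own code; the wording of the second error message is an approximation, and the jax version 0.4.25 in the usage line is a hypothetical example.

```python
from typing import Tuple


def _parse(version: str) -> Tuple[int, ...]:
    # Assumes plain numeric releases such as "0.4.20"; no rc/dev suffixes.
    return tuple(int(part) for part in version.split("."))


def check_jaxlib_version(jax_version: str, jaxlib_version: str,
                         minimum_jaxlib_version: str) -> None:
    """Raise RuntimeError if jaxlib is too old for jax, or newer than jax."""
    if _parse(jaxlib_version) < _parse(minimum_jaxlib_version):
        raise RuntimeError(
            f"jaxlib is version {jaxlib_version}, but this version "
            f"of jax requires version >= {minimum_jaxlib_version}.")
    if _parse(jaxlib_version) > _parse(jax_version):
        # Approximate wording; the real message in jax may differ.
        raise RuntimeError(
            f"jaxlib version {jaxlib_version} is newer than and incompatible "
            f"with jax version {jax_version}.")


# An old jaxlib (0.3.25) next to a newer jax reproduces the error above.
try:
    check_jaxlib_version("0.4.25", "0.3.25", "0.4.20")
except RuntimeError as err:
    print(err)
```

This is why pinning only jax is not enough on its own: a runtime that already ships an older jaxlib fails the first check, and one that pulls in a newer jaxlib can fail the second.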
Are you using this notebook: https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/batch/AlphaFold2_batch.ipynb
The error message suggests you are using an old version of it.
I tried running it from that notebook, with no change. I did notice that I only get this bug when using a TPU; GPU seems fine.
I pushed a fix for TPU; it should work again.
Thanks! Works great.
Since this has been fixed by the upgrade to dm-haiku 0.0.12, I think I can close this issue?
I still need to make a new pip release; I updated the conda package a few days ago.