ColabFold

Jax dependency Issue

Open · Aradhya-Tripathi opened this issue 2 years ago · 12 comments

JAX removed `linear_util` as of v0.4.25 (the latest release), so importing Haiku under the latest JAX crashes. Because ColabFold's pyproject.toml accepts any JAX version from 0.4.20 upward, this breaks several ColabFold methods.

Please let me know if this makes sense or if you need more information.

Aradhya-Tripathi, Feb 29 '24
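To see why an open-ended lower bound lets a breaking release in, here is a minimal sketch of a version-range check. It assumes plain numeric release strings such as `0.4.25` (no pre-release or local tags); the helpers `parse_version` and `satisfies` are hypothetical, and real installers apply the full PEP 440 rules (for example via the `packaging` library).

```python
from typing import Optional, Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """Split a plain numeric release like '0.4.25' into (0, 4, 25)."""
    return tuple(int(part) for part in v.split("."))


def satisfies(version: str, lower: str, upper: Optional[str] = None) -> bool:
    """Return True if lower <= version, and also version < upper when an upper bound is given."""
    v = parse_version(version)
    if v < parse_version(lower):
        return False
    return upper is None or v < parse_version(upper)


# "jax>=0.4.20" with no upper bound accepts 0.4.25, the release cited above as dropping linear_util.
print(satisfies("0.4.25", "0.4.20"))            # True
# A hypothetical upper bound of "<0.4.25" would have kept it out.
print(satisfies("0.4.25", "0.4.20", "0.4.25"))  # False
```

Comparing integer tuples rather than raw strings matters here: as strings, `"0.4.9" > "0.4.25"`, while as tuples `(0, 4, 9) < (0, 4, 25)`.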

Not sure how we can fix this. As far as I understand, we should not pin any jaxlib version, only jax; however, that doesn't prevent pip from installing a newer jaxlib.

Installing through poetry instead of pip should solve most issues, as it will install the versions specified in the lock file.

milot-mirdita, Mar 07 '24
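A lock file records exact versions for every package, including transitive ones like jaxlib, whereas pip only honors the ranges it is given. The sketch below compares installed distributions against a set of exact pins using the standard-library `importlib.metadata.version`. The function name `check_pins` and the pin values in the usage line are hypothetical and are not taken from ColabFold's actual lock file.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional


def check_pins(
    pins: Dict[str, str],
    lookup: Callable[[str], str] = version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed_version} for every package that does not match its pin.

    A value of None means the package is not installed at all.
    `lookup` defaults to importlib.metadata.version and can be replaced for testing.
    """
    mismatches: Dict[str, Optional[str]] = {}
    for name, pinned in pins.items():
        try:
            installed: Optional[str] = lookup(name)
        except PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[name] = installed
    return mismatches


# Hypothetical pins; an empty dict means the environment matches them exactly.
print(check_pins({"jax": "0.4.20", "jaxlib": "0.4.20"}))
```

Injecting `lookup` keeps the comparison logic testable without depending on whatever happens to be installed in the current environment.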

~I guess we have to wait for this to be merged:~ ~https://github.com/google-deepmind/dm-haiku/pull/739~

Never mind, this was already merged. However, we pin dm-haiku to 0.0.10 because I had some problem with 0.0.11, though I don't remember what it was.

milot-mirdita, Mar 07 '24

I pinned dm-haiku to 0.0.12; this should hopefully help avoid issues once Google upgrades JAX in Colab.

Thanks!

milot-mirdita, Mar 07 '24

I think we can update jax to the latest version and change the code to use `linear_util` from the `jax.extend` module (see here). I can make a PR for this, and maybe fix the JAX installs on Colab as well.

Aradhya-Tripathi, Mar 07 '24
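The change proposed in the previous comment amounts to importing `linear_util` from `jax.extend` and only falling back to the old `jax.linear_util` location on older JAX releases. Below is a minimal sketch of such a compatibility import, assuming the calling code only needs the module object. It is not the actual dm-haiku or ColabFold patch, and it sets `lu` to `None` so the snippet still runs where JAX is not installed at all.

```python
try:
    # Newer JAX: linear_util is exposed under the jax.extend namespace.
    from jax.extend import linear_util as lu
except ImportError:
    try:
        # Older JAX releases without jax.extend.linear_util still provide it here.
        from jax import linear_util as lu
    except ImportError:
        # Neither location is available (for example, JAX is not installed).
        lu = None

print("linear_util available:", lu is not None)
```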

dm-haiku 0.0.12 already does this. It should be fine now that the dependency is updated.

milot-mirdita, Mar 07 '24

I think this is a related (or the same) error. I returned to a "ColabFold v1.5.5: AlphaFold2 w/ MMseqs2 BATCH" session that was working yesterday, and now I get the error below. Is there a fix for this? I may just not understand what pinning dm-haiku 0.0.12 means.

    RuntimeError                              Traceback (most recent call last)
    in <cell line: 5>()
          3 import sys
          4
    ----> 5 from colabfold.batch import get_queries, run
          6 from colabfold.download import default_data_dir
          7 from colabfold.utils import setup_logging

    7 frames
    /usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version)
         62     msg = (f'jaxlib is version {jaxlib_version}, but this version '
         63            f'of jax requires version >= {minimum_jaxlib_version}.')
    ---> 64     raise RuntimeError(msg)
         65
         66   if _jaxlib_version > _jax_version:

    RuntimeError: jaxlib is version 0.3.25, but this version of jax requires version >= 0.4.20.

johnjacobpeters, Mar 07 '24
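The `RuntimeError` above is raised by a version guard that JAX runs at import time (`check_jaxlib_version`, in JAX's `jax/_src/lib` package). The sketch below reproduces the minimum-version part of that idea in simplified form. `check_jaxlib_minimum` is a hypothetical re-implementation, not JAX's code: it only handles plain numeric versions and omits the second comparison visible in the traceback, where jaxlib is checked against the jax version itself.

```python
from typing import Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """Split a plain numeric release like '0.4.20' into (0, 4, 20)."""
    return tuple(int(part) for part in v.split("."))


def check_jaxlib_minimum(jaxlib_version: str, minimum_jaxlib_version: str) -> None:
    """Raise RuntimeError if the installed jaxlib is older than the minimum jax requires."""
    if parse_version(jaxlib_version) < parse_version(minimum_jaxlib_version):
        msg = (f"jaxlib is version {jaxlib_version}, but this version "
               f"of jax requires version >= {minimum_jaxlib_version}.")
        raise RuntimeError(msg)


# The versions from this report: jaxlib 0.3.25 against a required minimum of 0.4.20.
try:
    check_jaxlib_minimum("0.3.25", "0.4.20")
except RuntimeError as err:
    print(err)  # jaxlib is version 0.3.25, but this version of jax requires version >= 0.4.20.
```

Because jax only states a lower bound for jaxlib, the check fails whenever an older jaxlib remains installed alongside a newer jax.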

Are you using this notebook: https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/batch/AlphaFold2_batch.ipynb

From the error message, it sounds like you are using an old version of it.

milot-mirdita, Mar 07 '24

I tried running from that notebook, with no change. I did notice that I only seem to get this bug when using a TPU runtime; GPU seems fine.

johnjacobpeters, Mar 07 '24

I pushed a fix for TPU; it should work again.

milot-mirdita, Mar 08 '24

Thanks! Works great.

johnjacobpeters, Mar 08 '24

Since this has been fixed by the upgrade to dm-haiku 0.0.12, can I close this issue?

Aradhya-Tripathi, Mar 12 '24

I still need to make a new pip release; I updated the conda package a few days ago.

milot-mirdita, Mar 12 '24