
bugfix for latest jax version (jnp.take() in jax>=0.3.8 returns NaNs)

Open sokrypton opened this issue 2 years ago • 1 comments

alphafold+multimer+templates returns NaN starting with jax version 0.3.8. @YoshitakaMo traced it down to def batched_gather() in alphafold/model/utils.py

you need to change: jnp.take(p, i, axis=axis) to jnp.take(p, i, axis=axis, mode="clip")
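For anyone patching locally, here is a rough sketch of what the patched helper could look like. The function body is an assumption (a jnp.take wrapped in jax.vmap once per batch dimension), not a copy of the code in alphafold/model/utils.py; only the mode="clip" argument comes from the fix above. The sample arrays are made up for illustration.

```python
import jax
import jax.numpy as jnp


def batched_gather(params, indices, axis=0, batch_dims=0):
    # Hypothetical reconstruction of the helper; the real body may differ.
    # mode="clip" clamps out-of-bounds indices into range instead of
    # letting jnp.take fill those positions with NaN.
    take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode="clip")
    # Map the gather over each leading batch dimension.
    for _ in range(batch_dims):
        take_fn = jax.vmap(take_fn)
    return take_fn(params, indices)


# Made-up inputs: index 7 is out of bounds for rows of length 3.
params = jnp.arange(6.0).reshape(2, 3)
indices = jnp.array([[0, 7], [1, 2]])
out = batched_gather(params, indices, batch_dims=1)
```

With the clip mode in place, the out-of-bounds index 7 is treated as the last valid index, so the result contains no NaNs.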

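The reason this bug occurs is that starting in jax 0.3.8, jnp.take() defaults to mode="fill", which returns invalid values (e.g. NaN) for out-of-bounds indices. Earlier versions clamped such indices into range, which is what mode="clip" restores. See the changelog: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-3-8-april-29-2022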
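A minimal way to see the difference on a toy array (the values and indices below are invented for illustration and assume jax>=0.3.8 is installed):

```python
import jax.numpy as jnp

values = jnp.array([10.0, 20.0, 30.0])
indices = jnp.array([0, 2, 5])  # 5 is out of bounds for a length-3 array

# Default behaviour on jax>=0.3.8: the out-of-bounds lookup is filled
# with an invalid value, which is NaN for floating-point arrays.
filled = jnp.take(values, indices, axis=0)

# mode="clip" clamps the index into the valid range, so 5 is read as 2.
clipped = jnp.take(values, indices, axis=0, mode="clip")
```

Here the last entry of filled is NaN, while clipped repeats the final element of values, which is why a single out-of-range index can poison downstream computations under the new default.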

the output with the alphafold pinned version of jax: [image]

the output starting with jax>=0.3.8: [image]

sokrypton, Jun 17 '22 05:06

Thanks for this suggestion! This has been fixed in https://github.com/deepmind/alphafold/releases/tag/v2.2.4

Htomlinson14, Sep 21 '22 16:09