alphafold
Bugfix for latest jax version (jnp.take() returns NaNs in jax>=0.3.8)
Running AlphaFold in multimer mode with templates returns NaN outputs, starting with jax version 0.3.8. @YoshitakaMo traced it down to def batched_gather() in alphafold/model/utils.py.
To fix it, you need to change:
jnp.take(p, i, axis=axis)
to
jnp.take(p, i, axis=axis, mode="clip")
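For illustration, here is a minimal sketch of a gather helper with the fix applied. It assumes batched_gather wraps jnp.take in one jax.vmap per batch dimension; this is a reconstruction to show where mode="clip" goes, not the verbatim AlphaFold source, and the example arrays are hypothetical.

```python
import jax
import jax.numpy as jnp


def batched_gather(params, indices, axis=0, batch_dims=0):
    """Gather along `axis`, vmapped over `batch_dims` leading dimensions.

    Sketch only (assumed structure): each jax.vmap strips one leading batch
    dimension from both `params` and `indices` before jnp.take runs.
    """
    def take_fn(p, i):
        # mode="clip" clamps out-of-range indices to the nearest valid index
        # instead of the jax>=0.3.8 default, which fills them with NaN.
        return jnp.take(p, i, axis=axis, mode="clip")

    fn = take_fn
    for _ in range(batch_dims):
        fn = jax.vmap(fn)
    return fn(params, indices)


# Usage (hypothetical data): index 7 is out of range for rows of length 3.
params = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
indices = jnp.array([[0, 2], [1, 7]])
out = batched_gather(params, indices, axis=0, batch_dims=1)
# out == [[1.0, 3.0], [5.0, 6.0]]; without mode="clip" the last entry is NaN.
```

Note that clipping hides the out-of-range index rather than reporting it, so the result for that position is the last valid element along the axis.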
The reason this bug occurs is that, starting in jax 0.3.8, jnp.take() returns NaN by default for out-of-bounds indices (see https://jax.readthedocs.io/en/latest/changelog.html#jax-0-3-8-april-29-2022).
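A small standalone check of the behavior described above, assuming jax>=0.3.8 is installed (the array values are made up for the demonstration). It compares the default mode of jnp.take with mode="clip" on a float array that is indexed out of bounds.

```python
import jax.numpy as jnp

# A small float array; valid indices along axis 0 are 0, 1 and 2.
params = jnp.array([10.0, 20.0, 30.0])
# Index 5 is out of bounds for an axis of length 3.
idx = jnp.array([0, 2, 5])

# Default mode in jax>=0.3.8: the out-of-bounds position is filled with NaN,
# so filled holds 10.0, 30.0, NaN.
filled = jnp.take(params, idx, axis=0)

# mode="clip": index 5 is clamped to the last valid index (2),
# so clipped holds 10.0, 30.0, 30.0.
clipped = jnp.take(params, idx, axis=0, mode="clip")
```

Because NaN propagates through later arithmetic, a single out-of-range gather is enough to turn downstream model outputs into NaN, which is consistent with the symptom reported here.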
With the jax version pinned by AlphaFold:
The output starting with jax>=0.3.8:
Thanks for this suggestion! This has been fixed in https://github.com/deepmind/alphafold/releases/tag/v2.2.4