
Pickle output result_model_1_multimer_v2_pred_0.pkl has a JAX dependency

Open tomgoddard opened this issue 1 year ago • 4 comments

In AlphaFold 2.2.4 the pickled per-structure output files from a multimer prediction such as

result_model_1_multimer_v2_pred_0.pkl

now contain a JAX dependency (apparently a jax DeviceArray structure was pickled). As a result, the .pkl data cannot be read by a Python interpreter that does not have jax and jaxlib installed, so the PAE data in the file is inaccessible when the file is moved to another machine for analysis. This significantly reduces the portability of the files and appears to be unintended. AlphaFold 2.2.0 multimer .pkl output does not have the JAX dependency.

Attempting to load the .pkl data without jax gives the following jax ModuleNotFoundError.

$ python3
Python 3.7.5 (default, Nov 26 2019, 14:12:06) 
[Clang 11.0.0 (clang-1100.0.33.12)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import pickle
>>> with open('result_model_1_multimer_v2_pred_0.pkl', 'rb') as f:
...   d = pickle.load(f)
... 
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
ModuleNotFoundError: No module named 'jax'
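
To see which modules a pickle file depends on without importing any of them, one option is to override `find_class` on `pickle.Unpickler` and record every global the file asks for. The following is only a sketch: `_Stub` and `referenced_globals` are hypothetical names made up for illustration, and the stub covers just the common reconstruction hooks, so an unusual pickle could still fail to scan.

```python
import io
import pickle


class _Stub:
    """Inert stand-in returned for every global the pickle asks for."""

    def __new__(cls, *args, **kwargs):
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        pass

    def _ignore(self, *args, **kwargs):
        pass

    # Hooks the unpickler may call while rebuilding objects; all are no-ops.
    __setstate__ = __setitem__ = append = extend = add = _ignore


def referenced_globals(data):
    """Return the set of (module, name) pairs referenced by pickle bytes."""
    seen = set()

    class Recorder(pickle.Unpickler):
        def find_class(self, module, name):
            seen.add((module, name))
            return _Stub  # never import the real module

    Recorder(io.BytesIO(data)).load()
    return seen
```

Passing the bytes of result_model_1_multimer_v2_pred_0.pkl to `referenced_globals` and looking for entries whose module starts with `jax` should show which pickled object pulls in the dependency. The return value of the scan itself is discarded, since every object in it is a stub.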

I suspect that adding a jax DeviceArray to the .pkl output was an accident introduced in AlphaFold 2.2.4 (or possibly in 2.2.3, 2.2.2 or 2.2.1). Removing this jax data from the .pkl file would help analysis of the AlphaFold predictions.

I develop the ChimeraX molecular visualization program, which supports AlphaFold PAE error analysis. That support was broken by the added jax dependency for reading the .pkl files.

I'll attach an example "result_model_1_multimer_v2_pred_0.pkl" from AlphaFold 2.2.4 that exhibits the problem.

tomgoddard avatar Nov 16 '22 01:11 tomgoddard

You can fix this by installing jax[cpu] into your environment. At least that works with my visualizer.

Intron7 avatar Nov 21 '22 09:11 Intron7

Yes, installing jax and jaxlib can work. But this affects thousands of users of ChimeraX, and we don't want to ship jax and jaxlib with our application just to open a file. I'm not sure what jax _Device_Array got pickled, but it is not the PAE data that ChimeraX is extracting from the file. I suspect pickling the _Device_Array was an accidental change in AlphaFold. I have hacked around the problem in our upcoming ChimeraX 1.5 release by making a fake jax module and trying again if loading the pickle file fails.

My purpose in reporting this is so that the pickle files can be more easily used by other developers, as they were in AlphaFold 2.2.0.

tomgoddard avatar Nov 21 '22 18:11 tomgoddard

Hi, I think what gets pickled is the distogram. I am also not sure whether it unpickles properly.

print(metadata['distogram']) -> {'bin_edges': DeviceArray([ 2.3125 , 2.625 , 2.9375 , 3.25 , 3.5625 , 3.875 , 4.1875 , 4.5 , 4.8125 , 5.125 , 5.4375 , 5.75 , 6.0625 , 6.375 , 6.6875 , 7. , 7.3125 , 7.625 , 7.9375 , 8.25 , 8.5625 , 8.875 , 9.1875 , 9.5 , 9.812499, 10.125 , 10.4375 , 10.75 , 11.0625 , 11.375 , 11.6875 , 12. , 12.3125 , 12.625 , 12.9375 , 13.25 , 13.5625 , 13.875 , 14.1875 , 14.499999, 14.8125 , 15.125 , 15.4375 , 15.75 , 16.0625 , 16.375 , 16.6875 , 16.999998, 17.312498, 17.625 , 17.9375 , 18.25 , 18.5625 , 18.875 , 19.1875 , 19.5 , 19.8125 , 20.125 , 20.437498, 20.75 , 21.0625 , 21.375 , 21.6875 ], dtype=float32), 'logits': DeviceArray([[[ 7.99458694e+01, -1.86893845e+01, -2.46103077e+01, ..., -1.09830141e+00, 1.10635986e+01, 2.55225778e+00], [ 5.37824154e+00, 1.02621212e+01, -1.24734116e+00, ..., -8.33254433e+00, -6.89190483e+00, 7.95750475e+00], [ 1.54681420e+00, 3.07494116e+00, 3.35776448e-01, ..., -9.31102562e+00, -6.93648672e+00, 5.46216071e-01], ..., [-1.10818462e+01, -9.99354744e+00, -1.05438919e+01, ..., -1.65992379e+00, -2.05721903e+00, 4.57334995e+00], [-1.16069813e+01, -1.08642635e+01, -1.08378592e+01, ...,

PAE is unaffected

MatthiasZeug avatar Nov 29 '22 15:11 MatthiasZeug

Hi, thanks for raising this. We will remove the jax dependency from this output and push an update.

Htomlinson14 avatar Jan 16 '23 16:01 Htomlinson14

Hi, this was addressed in https://github.com/deepmind/alphafold/commit/91ac85ac72e1c217cf2d42fae200082c5c46a26f. Thanks for raising this.

Htomlinson14 avatar Feb 07 '23 10:02 Htomlinson14