RFDesign

TypeError with af2_interface_metrics.py

Open aravinda1879 opened this issue 2 years ago • 3 comments

Hi, first of all, I am really not that familiar with JAX. My conda environment was built with the distributed yml file and thus got JAX 0.3.24, as shown below:

jax                       0.3.24                   pypi_0    pypi
jaxlib                    0.3.24                   pypi_0    pypi

However, when running af2_interface_metrics.py on the silent files, I get the following error. Any thoughts on this? I get the same error with both AF 2.3.1 and AF 2.2.4. af2_metrics.py, on the other hand, works without any issue.

Traceback (most recent call last):
  File "/software/RFDesign2/scripts/af2_interface_metrics.py", line 597, in
    predict_structure(tag_buffer, feature_dict_dict, binderlen_dict, initial_guess_dict, sfd_out, scorefilename)
  File "/software/RFDesign2/scripts/af2_interface_metrics.py", line 431, in predict_structure
    prediction_result = jax.vmap(model_runner.apply, in_axes=(None,None,0,0))(model_runner.params,
  File "/software/RFDesign2/envs/.conda/envs/rfdesign2/lib/python3.9/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/software/RFDesign2/envs/.conda/envs/rfdesign2/lib/python3.9/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
TypeError: _forward_fn() takes 1 positional argument but 2 were given
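The last two frames come from Haiku: the `apply` function produced by `hk.transform` takes `params` and an RNG first, then forwards every remaining positional argument to the wrapped forward function. The sketch below is a minimal pure-Python mimic of that forwarding, not Haiku's actual code; `mimic_transform`, `forward_fn`, and the dummy inputs are hypothetical. It reproduces the same kind of TypeError when a one-argument forward function receives two inputs.

```python
def mimic_transform(forward_fn):
    """Hypothetical stand-in for hk.transform(forward_fn).apply.

    Like Haiku's apply, it consumes params and rng itself and forwards
    every remaining positional argument to forward_fn.
    """
    def apply(params, rng, *args, **kwargs):
        # params and rng are not passed on; only *args reach forward_fn
        return forward_fn(*args, **kwargs)
    return apply


def forward_fn(batch):
    # Hypothetical forward function that accepts a single input dict
    return sorted(batch)


apply = mimic_transform(forward_fn)

params, rng = {}, 0
feature_dict = {"aatype": [0, 1, 2]}
initial_guess = [[0.0, 0.0, 0.0]]

print(apply(params, rng, feature_dict))  # prints ['aatype']

try:
    # Two inputs after params/rng: forward_fn receives 2 positional args
    apply(params, rng, feature_dict, initial_guess)
except TypeError as err:
    print(err)  # same "takes 1 positional argument but 2 were given" pattern
```

In other words, the count in the error message ignores `params` and the PRNG key; it only counts the inputs left over after `apply` has consumed those two.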

Appreciate it a lot. Thanks!

aravinda1879, Mar 30 '23 12:03

Okay, after reading a bit of model.py, here is the call on line 431 of af2_interface_metrics.py:

prediction_result = jax.vmap(model_runner.apply, in_axes=(None,None,0,0))(model_runner.params, jax.random.PRNGKey(0), processed_feature_dict, processed_initial_guess_dict)
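In that call, `in_axes=(None, None, 0, 0)` tells `jax.vmap` to pass the first two arguments (the params and the PRNG key) through unbatched and to map over the leading axis of the last two (the feature dict and the initial guess). `vmap` does not change how many arguments reach the wrapped function, so the TypeError is about argument count, not batching. A toy sketch, where `toy_apply` and all shapes are made up and are not AlphaFold's real features:

```python
import jax
import jax.numpy as jnp


def toy_apply(params, rng, features, initial_guess):
    # Hypothetical stand-in for model_runner.apply with the same four arguments.
    # params and rng arrive unbatched; features and initial_guess are one example each.
    return params["w"] * features["x"].sum() + initial_guess.sum()


params = {"w": jnp.float32(2.0)}
rng = jax.random.PRNGKey(0)
features = {"x": jnp.ones((3, 5))}    # batch of 3; leading axis is mapped
initial_guess = jnp.zeros((3, 4, 3))  # batch of 3; leading axis is mapped

# None = broadcast the whole argument (a pytree is fine), 0 = map over axis 0 of every leaf
batched = jax.vmap(toy_apply, in_axes=(None, None, 0, 0))
out = batched(params, rng, features, initial_guess)
print(out.shape)  # (3,)
```

Each of the three mapped calls sees `features["x"]` with shape `(5,)` and an initial guess with shape `(4, 3)`, so every output element is `2.0 * 5 + 0 = 10.0`.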

module.AlphaFold is compiled by JAX's jit, and the batch should be a dictionary of inputs to the AlphaFold model.

However, _forward_fn() takes only 1 positional argument, while line 431 of af2_interface_metrics.py currently passes 2. Since the tutorial was presumably run with this code, I must be missing something. Again, I am pretty new to JAX and would appreciate an idiot's 101 explanation for this. :) TIA!
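One way to see the mismatch without running JAX at all is to compare the call against the forward function's signature. Once `apply` has consumed the params and the PRNG key, line 431 hands two inputs (`processed_feature_dict` and `processed_initial_guess_dict`) to `_forward_fn`, which, per the traceback, accepts only one. The sketch below uses `inspect.signature(...).bind` on two hypothetical forward functions: `stock_forward` (one input, like the `_forward_fn` in the error) and `initial_guess_forward` (an assumed variant that also accepts an initial guess). Neither is the actual AlphaFold or RFDesign code.

```python
import inspect


def stock_forward(batch):
    # Hypothetical: one positional input, like the _forward_fn in the traceback
    return batch


def initial_guess_forward(batch, initial_guess=None):
    # Hypothetical variant that also accepts an initial-guess input
    return batch, initial_guess


def accepts(fn, *args):
    """Return True if fn can be called with these positional args."""
    try:
        inspect.signature(fn).bind(*args)
        return True
    except TypeError:
        return False


# What reaches the forward function after params and the PRNG key are consumed
call_args = ({"aatype": [0]}, {"initial_guess": [0.0]})

print(accepts(stock_forward, *call_args))          # False
print(accepts(initial_guess_forward, *call_args))  # True
```

If the AlphaFold package your environment imports behaves like `stock_forward`, the two-input call fails no matter which JAX version is installed, which would be consistent with the same error appearing under both AF 2.3.1 and AF 2.2.4.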

aravinda1879, Mar 30 '23 19:03

I hadn't noticed that #48 and #38 are basically the same issue as this one.

aravinda1879, Mar 30 '23 20:03

See my reply in https://github.com/RosettaCommons/RFDesign/issues/48

jueseph, Mar 30 '23 20:03