Type Error with af2_interface_metrics.py
Hi,
First of all, I am not very familiar with JAX. My conda environment was built from the distributed yml file, so it has JAX 0.3.24, as shown below:
jax 0.3.24 pypi_0 pypi
jaxlib 0.3.24 pypi_0 pypi
However, when running af2_interface_metrics.py with the silent files, I get the following error. Any thoughts on this? I get the same error with both AF 2.3.1 and AF 2.2.4. By contrast, af2_metrics.py works without any issue.
Traceback (most recent call last):
File "/software/RFDesign2/scripts/af2_interface_metrics.py", line 597, in
Appreciate it a lot. Thanks!
Okay, after reading a little of model.py, here is where I am stuck. The call in question is:
prediction_result = jax.vmap(model_runner.apply, in_axes=(None,None,0,0))(model_runner.params, jax.random.PRNGKey(0), processed_feature_dict, processed_initial_guess_dict)
module.AlphaFold is compiled with JAX's jit, and the batch should be a dictionary of inputs to the AlphaFold model.
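To make the in_axes=(None,None,0,0) part of the jax.vmap call above concrete, here is a minimal sketch. toy_apply and the dictionaries are hypothetical stand-ins, not the real model_runner.apply or the real feature dicts. None means "pass this argument unchanged to every call", and 0 means "split this argument along its leading axis".

```python
import jax
import jax.numpy as jnp

def toy_apply(params, rng, features, guess):
    # Hypothetical stand-in for model_runner.apply: sees ONE example at a time.
    del rng  # unused in this toy function
    return params["w"] * jnp.sum(features["x"]) + jnp.sum(guess["g"])

params = {"w": jnp.float32(2.0)}      # shared across the batch (in_axes=None)
rng = jax.random.PRNGKey(0)           # shared across the batch (in_axes=None)
features = {"x": jnp.ones((3, 4))}    # leading axis of size 3 is mapped (in_axes=0)
guess = {"g": jnp.zeros((3, 5))}      # leading axis of size 3 is mapped (in_axes=0)

# vmap calls toy_apply once per index of the leading axis and stacks the results.
batched_apply = jax.vmap(toy_apply, in_axes=(None, None, 0, 0))
out = batched_apply(params, rng, features, guess)
print(out.shape)
```

So the vmapped function still receives four positional arguments per example. vmap only changes how each one is sliced, not how many there are.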

However, _forward_fn() takes only 1 positional argument, yet line 431 of af2_interface_metrics.py currently passes 2. Is that right?
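A minimal sketch of how such a mismatch could arise, assuming model_runner.apply behaves like a Haiku-style transformed function whose signature is apply(params, rng, *args). toy_transform, forward_one and forward_two are hypothetical names used only for illustration, written in plain Python so it runs without JAX.

```python
def toy_transform(fn):
    """Hypothetical stand-in for a Haiku-style transform of fn."""
    def apply(params, rng, *args):
        # params and rng are consumed by the wrapper;
        # every remaining positional argument is forwarded to fn.
        return fn(*args)
    return apply

def forward_one(batch):
    # Like a _forward_fn that only accepts the batch.
    return sum(batch["x"])

def forward_two(batch, initial_guess):
    # Like a _forward_fn that also accepts an initial guess.
    return sum(batch["x"]) + sum(initial_guess["y"])

batch = {"x": [1.0, 2.0]}
guess = {"y": [0.5]}

# Works: the wrapped function accepts both forwarded arguments.
ok = toy_transform(forward_two)(None, None, batch, guess)

# Fails: two arguments are forwarded to a function that accepts one.
try:
    toy_transform(forward_one)(None, None, batch, guess)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(ok)
print(error_message)
```

If that assumption holds, passing both processed_feature_dict and processed_initial_guess_dict only works when the installed model code defines a _forward_fn that accepts the second argument.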
I assume I am the one who is wrong, since the tutorial was done with this code, so I wonder what I am missing here. Again, I am pretty new to JAX and would appreciate an idiot's 101 explanation of this. :)
TIA!
I didn't see that #48 and #38 were basically the same issue as this.
See my reply in https://github.com/RosettaCommons/RFDesign/issues/48