openfold
Shape Mismatch with batch_size=2
Hi,
When I set batch_size: 2 in config.py, the following error occurs during execution:
```
  File "/home/ganjh/data/openfold/openfold/model/model.py", line 568, in forward
    outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
  File "/home/ganjh/data/openfold/openfold/model/model.py", line 325, in iteration
    template_embeds = self.embed_templates(
  File "/home/ganjh/data/openfold/openfold/model/model.py", line 142, in embed_templates
    template_embeds = self.template_embedder(
  File "/home/ganjh/mambaforge/envs/openfold_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ganjh/data/openfold/openfold/model/embedders.py", line 943, in forward
    pair_act = self.template_pair_embedder(
  File "/home/ganjh/mambaforge/envs/openfold_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ganjh/data/openfold/openfold/model/embedders.py", line 791, in forward
    pseudo_beta_mask_2d *= multichain_mask_2d
RuntimeError: output with shape [2, 1, 200, 200] doesn't match the broadcast shape [2, 2, 200, 200]
```
I also set crop_size: 200 in config.py.
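The error comes from the in-place multiply: for `a *= b`, PyTorch requires the broadcast shape of `a` and `b` to equal the shape of `a` itself. One plausible explanation, which is an assumption not confirmed by the traceback, is that `multichain_mask_2d` has shape [2, 200, 200] (no template axis). Shapes are aligned from the right, so its batch axis lines up with the template axis of `pseudo_beta_mask_2d` ([2, 1, 200, 200]), giving [2, 2, 200, 200]. With batch_size: 1 both leading sizes would be 1 and the multiply would succeed, which would explain why only batch_size: 2 fails. The pure-Python sketch below (the helper names `broadcast_shape` and `inplace_ok` are hypothetical) reproduces the shape arithmetic without needing torch:

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of two tensor shapes (NumPy/PyTorch rules)."""
    result = []
    # Align shapes from the trailing dimension; missing leading dims count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        # A size-1 dim stretches to the other size; otherwise both sizes agree.
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


def inplace_ok(out_shape, other_shape):
    """`out *= other` is only legal if broadcasting leaves out's shape unchanged."""
    return broadcast_shape(out_shape, other_shape) == tuple(out_shape)


# Shapes from the error message; the mask shape is an assumption.
pseudo_beta = (2, 1, 200, 200)    # [batch, template, N_res, N_res]
mask_no_templ = (2, 200, 200)     # hypothetical: mask missing the template axis
mask_with_templ = (2, 1, 200, 200)

print(broadcast_shape(pseudo_beta, mask_no_templ))    # (2, 2, 200, 200)
print(inplace_ok(pseudo_beta, mask_no_templ))         # False -> the RuntimeError
print(inplace_ok(pseudo_beta, mask_with_templ))       # True
print(inplace_ok((1, 1, 200, 200), (1, 200, 200)))    # True: batch_size 1 hides the bug
```

If this hypothesis holds, inserting a singleton template axis into the mask before the multiply (in PyTorch, something like `multichain_mask_2d.unsqueeze(-3)`) would make the shapes agree. This is an untested suggestion, not a fix confirmed by the OpenFold maintainers.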
Hi,
Have you found a fix for this? I am encountering the same issue.