openfold
Fix periodicity batching bug for `supervised_chi_loss` in loss.py
- Fixes the periodicity batching bug for `supervised_chi_loss` discovered in https://github.com/aqlaboratory/openfold/issues/381.
- Fixes the shape of `unnormalized_angles_sin_cos` for the (batched) angle norm loss calculation.
- Also corrects the docstring for this function.
- This corrected version of the function tested fine for me with both batched inputs, e.g. `angles_sin_cos=[batch_size, num_residues, 7, 2]`, and single-example inputs, e.g. `angles_sin_cos=[num_residues, 7, 2]`.
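The actual change lives in loss.py and is not reproduced here. What follows is a minimal NumPy sketch, not OpenFold's code, that illustrates the two shape properties the bullets describe. First, the per-residue chi π-periodicity flags are looked up per example, so they keep the batch dimensions. Second, both losses reduce only over the trailing residue and angle axes, so `[batch_size, num_residues, 7, 2]` and `[num_residues, 7, 2]` inputs give consistent results. Several details are assumptions: the function names, the `periodic_table` argument, the per-example reduction, and the convention that the 4 chi angles are the last 4 of the 7 predicted angles.

```python
import numpy as np


def chi_pi_periodic_lookup(aatype, periodic_table):
    """Hypothetical helper: per-residue chi pi-periodicity flags, shape [*, N, 4].

    aatype: int array [*, N]; periodic_table: float array [num_restypes, 4].
    Integer-array indexing keeps every leading (batch) axis, so each
    example's flags stay aligned with that example's chi angles.
    """
    return periodic_table[aatype]


def supervised_chi_loss_sketch(angles_sin_cos, true_chi_sin_cos, aatype,
                               chi_mask, periodic_table, eps=1e-6):
    """Per-example chi loss (hypothetical sketch); returns shape [*].

    angles_sin_cos: [*, N, 7, 2] (assumed: last 4 angles are chi1..chi4)
    true_chi_sin_cos: [*, N, 4, 2]; chi_mask: [*, N, 4]
    """
    pred_chi = angles_sin_cos[..., 3:, :]                      # [*, N, 4, 2]
    periodic = chi_pi_periodic_lookup(aatype, periodic_table)   # [*, N, 4]
    # Adding pi to an angle negates both its sin and its cos.
    shift = (1.0 - 2.0 * periodic)[..., None]                   # [*, N, 4, 1]
    true_shifted = shift * true_chi_sin_cos
    sq = np.sum((pred_chi - true_chi_sin_cos) ** 2, axis=-1)
    sq_shifted = np.sum((pred_chi - true_shifted) ** 2, axis=-1)
    sq = np.minimum(sq, sq_shifted)                            # [*, N, 4]
    # Reduce only over residues and chi angles, never over batch axes.
    return np.sum(sq * chi_mask, axis=(-2, -1)) / (
        np.sum(chi_mask, axis=(-2, -1)) + eps)


def angle_norm_loss_sketch(unnormalized_angles_sin_cos, seq_mask, eps=1e-6):
    """Mean |norm - 1| of each raw (sin, cos) pair (hypothetical sketch).

    unnormalized_angles_sin_cos: [*, N, 7, 2]; seq_mask: [*, N].
    Returns shape [*].
    """
    norm = np.sqrt(np.sum(unnormalized_angles_sin_cos ** 2, axis=-1) + eps)
    err = np.abs(norm - 1.0)                                     # [*, N, 7]
    mask = np.broadcast_to(seq_mask[..., None], err.shape)       # [*, N, 7]
    return np.sum(err * mask, axis=(-2, -1)) / (
        np.sum(mask, axis=(-2, -1)) + eps)
```

Plain integer indexing was chosen here because it preserves any number of leading axes without extra bookkeeping; the PR's real fix may use a different mechanism. A quick consistency check is to compare each row of a batched call against the same example passed on its own.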