
Incorrect (batched) einsum in `supervised_chi_loss()`?

Open amorehead opened this issue 1 year ago • 0 comments

Hello. Thank you all for making this work fully open-source.

I have a question about the `supervised_chi_loss()` function. When constructing `chi_pi_periodic`, shouldn't the einsum equation be `...ij,jk->...ik` rather than `...ij,jk->ik`, so that the resulting tensor can (potentially) carry a batch dimension (e.g., as its first dimension)? Otherwise, I don't see how the remaining code for this loss function can correctly account for the chi-angle periodicity of the different sequence inputs within a given batch, since these periodicities will generally vary from sequence to sequence.

Without this change, some local tests of mine show that the resulting tensor always has shape `[num_residues, 4]`, and changing the effective batch size does not change this shape (implying that `chi_pi_periodic` is currently batch-agnostic).

https://github.com/aqlaboratory/openfold/blob/2dc080ce0bf83a9f90dfc75a799d754db68af104/openfold/utils/loss.py#L329
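To illustrate the difference, here is a minimal sketch using NumPy's `einsum` (which follows the same equation semantics as `torch.einsum`). The shapes and variable names (`aatype_one_hot`, `chi_pi_periodic_table`) are illustrative stand-ins for the tensors in the loss function, not the actual OpenFold code:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_res, num_restypes, num_chi = 2, 7, 21, 4

# One-hot residue types per sequence in the batch (illustrative stand-in
# for the one-hot aatype tensor used in the loss), shape [batch, num_res, 21].
aatype_one_hot = np.eye(num_restypes)[
    rng.integers(0, num_restypes, (batch_size, num_res))
]

# Per-residue-type chi periodicity table (stand-in for
# residue_constants.chi_pi_periodic), shape [21, 4].
chi_pi_periodic_table = rng.integers(0, 2, (num_restypes, num_chi)).astype(float)

# Proposed batched equation: the ellipsis carries the batch dimension through.
batched = np.einsum("...ij,jk->...ik", aatype_one_hot, chi_pi_periodic_table)
print(batched.shape)  # (2, 7, 4): one periodicity mask per sequence

# Sanity check against an explicit per-sequence matmul.
reference = np.stack([seq @ chi_pi_periodic_table for seq in aatype_one_hot])
assert np.allclose(batched, reference)
```

With `...ik` in the output, each sequence in the batch keeps its own periodicity mask, whereas the observation above suggests the current `...ij,jk->ik` equation collapses the batch dimension and yields a single `[num_residues, 4]` tensor.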

I think the reason this was never caught before (i.e., never threw an error) is that PyTorch automatically broadcasts the shape of `shifted_mask` to match that of `true_chi` by prepending dummy dimensions to `shifted_mask`. Nonetheless, the resulting shifting logic is still "batch-agnostic".

https://github.com/aqlaboratory/openfold/blob/2dc080ce0bf83a9f90dfc75a799d754db68af104/openfold/utils/loss.py#L337
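A small sketch of why this fails silently (NumPy broadcasting behaves the same way as PyTorch's here; the shapes are illustrative):

```python
import numpy as np

batch_size, num_res, num_chi = 2, 7, 4

# A batch-agnostic mask shaped like the current shifted_mask: [num_res, 4].
shifted_mask = np.ones((num_res, num_chi))

# A batched angle tensor shaped like true_chi: [batch, num_res, 4].
true_chi = np.zeros((batch_size, num_res, num_chi))

# Broadcasting silently left-pads shifted_mask with a dummy batch
# dimension, so the elementwise product never raises -- but every
# sequence in the batch receives the *same* mask.
shifted = shifted_mask * true_chi
print(shifted.shape)  # (2, 7, 4)
```

No shape error is ever raised, so the batch-agnostic mask goes unnoticed unless the values themselves are checked per sequence.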

amorehead avatar Dec 07 '23 23:12 amorehead