openfold
Bug in interface detection?
Why isn't there an additional broadcast dimension over the atom indices? Won't this code only calculate pairwise distances between the same atom types in different residues instead of all pairwise distances between two residues?
https://github.com/aqlaboratory/openfold/blob/e938c184a291bf053af3b14c1e3e8bb29aee57e2/openfold/data/data_transforms_multimer.py#L320
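Here is a small sketch of the concern. It assumes, as the question describes, that the linked line adds a broadcast dimension over residues only. It uses NumPy because PyTorch follows NumPy's broadcasting rules, so the shapes behave the same way. The coordinates are made up: two residues on different chains, each with two atom slots. Atom 1 of residue 0 sits 1 Å from atom 0 of residue 1.

```python
import numpy as np

# Hypothetical toy input: positions has shape [N_res=2, N_atom=2, 3]
positions = np.array([
    [[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]],   # residue 0: atom slots 0 and 1
    [[11.0, 0.0, 0.0], [30.0, 0.0, 0.0]],  # residue 1: atom slots 0 and 1
])

# Residue broadcast only: atom a of residue i vs atom a of residue j
same_type = positions[:, None, :, :] - positions[None, :, :, :]              # [N, N, A, 3]
same_type_dists = np.sqrt(np.sum(same_type ** 2, axis=-1))                  # [N, N, A]

# Residue and atom broadcast: atom a of residue i vs atom b of residue j
all_pairs = positions[:, None, :, None, :] - positions[None, :, None, :, :]  # [N, N, A, A, 3]
all_pair_dists = np.sqrt(np.sum(all_pairs ** 2, axis=-1))                   # [N, N, A, A]

print(same_type_dists[0, 1].min())  # 11.0: closest same-type atom pair
print(all_pair_dists[0, 1].min())   # 1.0: closest atom pair of any types
```

With a 5 Å threshold, the same-type version would miss this contact. The all-pairs version would find it.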
The following version adds broadcast dimensions over both residues (sequence length) and atoms, so every atom of one residue is compared with every atom of the other:
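```python
import torch


def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
    # positions: [..., N_res, N_atom, 3]; atom_mask: [..., N_res, N_atom]; asym_id: [..., N_res]
    # coord_diff[..., i, j, a, b, :] = positions[..., i, b, :] - positions[..., j, a, :]
    coord_diff = positions[..., None, None, :, :] - positions[..., None, :, :, None, :]
    pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))  # [..., N_res, N_res, N_atom, N_atom]
    # Only residue pairs on different chains can form an interface
    diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
    # Both atoms of each (i, j, a, b) pair must be present
    pair_mask = atom_mask[..., None, None, :] * atom_mask[..., None, :, :, None]
    mask = (diff_chain_mask[..., None, None] * pair_mask).bool()
    pairwise_dists = torch.where(mask, pairwise_dists, torch.inf)
    # Minimum over all atom pairs = closest approach of each residue pair
    min_dist_per_res = pairwise_dists.amin(dim=(-1, -2))
    # Per residue: number of residues on other chains within the threshold
    valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
    # [0] assumes no batch dimensions, i.e. positions is [N_res, N_atom, 3]
    interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
    return interface_residues_idxs
```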
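One way to check the broadcast version is to compare it against an explicit loop over residue pairs and atom pairs. The sketch below is a NumPy port of the function above for unbatched input only. It is not OpenFold's code. `interface_residues_vectorized` and `interface_residues_loop` are hypothetical names, and the data is random toy data.

```python
import numpy as np


def interface_residues_vectorized(positions, atom_mask, asym_id, threshold):
    # NumPy port of the broadcasting above; positions is [N_res, N_atom, 3]
    diff = positions[..., None, None, :, :] - positions[..., None, :, :, None, :]
    dists = np.sqrt(np.sum(diff ** 2, axis=-1))                    # [N, N, A, A]
    diff_chain = asym_id[..., None, :] != asym_id[..., :, None]    # [N, N]
    pair_mask = atom_mask[..., None, None, :] * atom_mask[..., None, :, :, None]
    mask = diff_chain[..., None, None] & pair_mask.astype(bool)
    dists = np.where(mask, dists, np.inf)
    min_dist = dists.min(axis=(-1, -2))                            # [N, N]
    return np.nonzero((min_dist < threshold).sum(axis=-1))[0]


def interface_residues_loop(positions, atom_mask, asym_id, threshold):
    # Reference: explicit loops over residue pairs on different chains and all atom pairs
    n_res, n_atom, _ = positions.shape
    hits = []
    for i in range(n_res):
        found = False
        for j in range(n_res):
            if asym_id[i] == asym_id[j]:
                continue
            for a in range(n_atom):
                for b in range(n_atom):
                    if atom_mask[i, a] and atom_mask[j, b]:
                        d = np.sqrt(np.sum((positions[i, a] - positions[j, b]) ** 2))
                        if d < threshold:
                            found = True
        if found:
            hits.append(i)
    return np.array(hits, dtype=np.int64)


rng = np.random.default_rng(0)
positions = rng.uniform(0.0, 20.0, size=(6, 4, 3))
atom_mask = (rng.uniform(size=(6, 4)) > 0.3).astype(np.float32)
asym_id = np.array([0, 0, 0, 1, 1, 1])
vec = interface_residues_vectorized(positions, atom_mask, asym_id, 8.0)
ref = interface_residues_loop(positions, atom_mask, asym_id, 8.0)
assert np.array_equal(vec, ref)
print(vec)
```

The vectorized version materializes an `[N_res, N_res, N_atom, N_atom, 3]` array. Its memory therefore grows quadratically in both the residue count and the atom count. The loop is only practical as a correctness check on small inputs.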