
feat: add AF-sample2 style random MSA masking (--mask_msa)

Open suzuki-2001 opened this issue 7 months ago • 5 comments

AF-sample2 MSA Masking

Congratulations on the Boltz-2 release; it's a great milestone! Building on that success, this PR adds AF-sample2-style random MSA masking to help Boltz explore alternative conformations.

🚀 Why this PR matters

AF-sample2 demonstrated that randomly masking MSA columns (“AF-sample2 masking”) forces AlphaFold2 to explore alternative conformations. On the OC23 benchmark, it improved the non-preferred state in 9 of 23 targets (ΔTM > 0.05) while preserving the preferred state.

Boltz-1 currently predicts essentially a single conformation per input; adding the same masking knob gives users a lightweight way to sample alternative conformations without touching the diffusion core.

Figure 1. Illustration of the AF-sample2 MSA masking method (from the AF-sample2 paper, "AF-sample2 predicts multiple conformations and ensembles with AlphaFold2", Fig. 1a).

The AF-sample2 pipeline starts by generating MSAs for a given protein sequence. It then applies randomized MSA masking, replacing a percentage of the columns with X (denoting “unknown residue”), so that a unique MSA profile is fed into the model on every inference run.
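A minimal pure-Python sketch of the column masking described above, assuming the MSA is a list of equal-length aligned strings with the query in row 0, which is kept unmasked as in this PR's implementation. The function name `mask_msa_columns` and the choice to mask an exact number of columns are illustrative assumptions, not AF-sample2's code.

```python
import random
from typing import List


def mask_msa_columns(msa: List[str], mask_rate: float, seed: int) -> List[str]:
    """Replace a random fraction of MSA columns with 'X' in every non-query row.

    Hypothetical helper (sketch only). Assumes row 0 is the query sequence
    and that all rows are aligned strings of equal length.
    """
    if not 0.0 <= mask_rate <= 1.0:
        raise ValueError("mask_rate must be in [0, 1]")
    if not msa:
        return []
    n_cols = len(msa[0])
    rng = random.Random(seed)  # seeded: the same seed reproduces the same mask
    n_masked = round(mask_rate * n_cols)
    masked_cols = set(rng.sample(range(n_cols), n_masked))
    out = [msa[0]]  # the query row is never masked
    for row in msa[1:]:
        # Same column set for every non-query row, i.e. whole columns are masked
        out.append("".join("X" if j in masked_cols else c for j, c in enumerate(row)))
    return out
```

Calling it again with a different `seed` picks a different set of columns, which is how each inference run can see a distinct MSA profile.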

 

📊 Experiment Results

On the OC23 set (23 two-state targets), the non-preferred state improved by ΔTM > 0.05 in 9 of 23 targets while the preferred state stayed intact (TM-scores measured with USalign). Note on MSAs: all evaluations reuse the multiple-sequence alignments distributed with the original AF-sample2 repository.

Figure 2. Heat map of Boltz-1 with random column masking, evaluated on the OC23 benchmark. For 9 of 23 targets, masking improved over the default prediction (ΔTM-score > 0.05, measured with USalign). For each MSA, 10 diffusion samples were generated, and the plot shows the best TM-score obtained at each masking rate. A mask rate of 0% corresponds to the baseline TM-score.
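The best-of-samples evaluation described in Figure 2 can be sketched as follows, assuming per-sample TM-scores against the reference structure have already been computed (e.g. with USalign). The dictionary layout, the function names, and the use of the best 0% sample as the baseline are assumptions for illustration; the toy numbers in the checks are invented, not benchmark data. Because this selection needs the reference structure, it only works in an oracle setting.

```python
from typing import Dict, List

# Hypothetical layout: {mask_rate: [TM-score of each diffusion sample
# against the reference structure]}
TmByRate = Dict[float, List[float]]


def best_tm_per_rate(tm_scores: TmByRate) -> Dict[float, float]:
    """Best (oracle) TM-score among the diffusion samples at each mask rate."""
    return {rate: max(scores) for rate, scores in tm_scores.items()}


def improved_over_baseline(tm_scores: TmByRate, threshold: float = 0.05) -> bool:
    """True if some nonzero mask rate beats the 0% baseline by more than threshold."""
    best = best_tm_per_rate(tm_scores)
    baseline = best[0.0]  # assumption: baseline = best sample at mask rate 0%
    return any(tm - baseline > threshold for rate, tm in best.items() if rate > 0.0)


def count_improved(targets: Dict[str, TmByRate], threshold: float = 0.05) -> int:
    """Number of targets counted as improved under this oracle criterion."""
    return sum(improved_over_baseline(t, threshold) for t in targets.values())
```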

 

What’s in this PR

  • CLI flags

    • --mask_msa (bool, default False)
    • --mask_rate_msa (float 0–1, default 0.1)
    • --mask_seed_msa (int, default 42)
  • MSA masking in MSAModule

    • trunk.py
    • trunkv2.py

After (optional) MSA subsampling, build a random mask over the MSA, exclude the query row, and set the UNK channel (index 20) of the masked positions to 1.
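The masking step above might look like this in NumPy, assuming a one-hot MSA of shape `(n_seqs, n_res, n_classes)` with the query in row 0 and UNK at index 20 as stated. `mask_onehot_msa` is a hypothetical name; Boltz's own modules (`trunk.py`, `trunkv2.py`) operate on PyTorch tensors. This sketch also zeroes the other channels at masked positions so each one becomes a clean UNK one-hot vector, which the description does not specify.

```python
import numpy as np

UNK_INDEX = 20  # UNK index as given in the PR description


def mask_onehot_msa(msa_onehot: np.ndarray, mask_rate: float, seed: int) -> np.ndarray:
    """Column-mask a one-hot MSA of shape (n_seqs, n_res, n_classes).

    Hypothetical sketch: row 0 (the query) is never masked; masked entries are
    rewritten as one-hot UNK (all channels zeroed, then UNK set to 1).
    """
    out = msa_onehot.copy()  # leave the caller's array untouched
    rng = np.random.default_rng(seed)
    n_res = out.shape[1]
    cols = rng.choice(n_res, size=round(mask_rate * n_res), replace=False)
    mask = np.zeros(out.shape[:2], dtype=bool)
    mask[1:, cols] = True  # same columns in every non-query row
    rows, res = np.nonzero(mask)
    out[rows, res, :] = 0
    out[rows, res, UNK_INDEX] = 1
    return out
```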

# usage
boltz predict [input_fasta_or_yaml] --mask_msa --mask_rate_msa 0.1 --mask_seed_msa 42
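The following argparse sketch mirrors the three flags and defaults listed above and rejects a mask rate outside 0–1. It is illustrative only and is not how Boltz defines its CLI; the parser, the `unit_interval` helper, and the `predict-sketch` program name are hypothetical.

```python
import argparse


def unit_interval(value: str) -> float:
    """argparse type that accepts a float in [0, 1]."""
    x = float(value)
    if not 0.0 <= x <= 1.0:
        raise argparse.ArgumentTypeError(f"{value!r} is not in [0, 1]")
    return x


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser reproducing the flag names and defaults from this PR."""
    parser = argparse.ArgumentParser(prog="predict-sketch")
    parser.add_argument("input", help="input FASTA or YAML file")
    parser.add_argument("--mask_msa", action="store_true", help="enable random MSA masking")
    parser.add_argument("--mask_rate_msa", type=unit_interval, default=0.1,
                        help="fraction of MSA columns to mask (0-1)")
    parser.add_argument("--mask_seed_msa", type=int, default=42,
                        help="random seed for the mask")
    return parser
```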

 

🛠️ Dev Env

| Item | Spec / Version |
|---|---|
| GPU | NVIDIA RTX 4080 |
| CPU | AMD Ryzen Threadripper PRO 7955WX (16 cores) |
| OS | Linux 6.11.0-19-generic (rtx4080 host) |
| Python | 3.10 (miniforge) |
| Install | pip install -e . and pip install tensorboard |

Welcome Feedback

We would greatly appreciate any feedback on the implementation details, design choices, or further ideas to improve this feature.

suzuki-2001, Jun 06 '25 18:06

Thanks for this contribution! Is the confidence model able to identify when a better structure has been found, or does this only work in an oracle setting?

jwohlwend, Jun 08 '25 19:06

@jwohlwend That's the key question, thank you for asking!

You are correct: on its own, the confidence model isn't a reliable way to pick out the best structure among the samples. For now, selecting the best candidate requires external validation, as we did in our benchmark (the oracle setting).

The main strength of this feature is generating a diverse set of high-quality conformations (roughly comparable to MSA subsampling). We see improving the automated selection process as an important next step.

suzuki-2001, Jun 08 '25 19:06

Cool, I’m good to merge with some documentation in prediction.md that also highlights this point! Would you be open to writing that?

jwohlwend, Jun 08 '25 20:06

Great! I'll add the documentation and request your review once it's updated.

suzuki-2001, Jun 08 '25 20:06

@jwohlwend Hi, I've updated prediction.md:

  • Added the new arguments
  • Included the note about the limitations of confidence scores
  • Made a cosmetic tweak to the table's Markdown source

Ready for another review when you have a moment. Thanks!

suzuki-2001, Jun 09 '25 18:06

Hi @jwohlwend, this PR is now updated and ready for review.

  • The --mask_msa logic was adjusted to apply masking before one-hot encoding, ensuring consistency with how MSA masking operates at the raw token level.
  • The masking is now only applied if --mask_rate_msa > 0, simplifying CLI usage.
  • All conflicts have been resolved, and the feature was tested on v2.1.1.
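Masking at the raw token level, before one-hot encoding, can be sketched as below, assuming an integer token MSA of shape `(n_seqs, n_res)` with the query in row 0 and UNK = 20. `mask_msa_tokens` and `one_hot` are hypothetical helpers, not the PR's code; the early return reflects the rule that masking is applied only when the rate is greater than 0.

```python
import numpy as np

UNK_INDEX = 20  # UNK token index, as stated earlier in this thread


def mask_msa_tokens(msa_tokens: np.ndarray, mask_rate: float, seed: int) -> np.ndarray:
    """Column-mask an integer-token MSA of shape (n_seqs, n_res) before one-hot encoding."""
    if mask_rate <= 0.0:  # no masking at all when the rate is 0
        return msa_tokens
    out = msa_tokens.copy()
    rng = np.random.default_rng(seed)
    n_res = out.shape[1]
    cols = rng.choice(n_res, size=round(mask_rate * n_res), replace=False)
    out[1:, cols] = UNK_INDEX  # the query row (row 0) is left untouched
    return out


def one_hot(tokens: np.ndarray, n_classes: int) -> np.ndarray:
    """Encode integer tokens as one-hot vectors along a new last axis."""
    return np.eye(n_classes, dtype=np.float32)[tokens]
```

Because each masked position becomes the UNK token before encoding, it ends up as a single UNK channel after one-hot encoding, with no residue channel left over.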

When testing with the examples, I encountered KeyError: 'profile_affinity' at boltz2.py:622. Setting affinity=False allowed me to verify that mask_msa works correctly. This issue appears unrelated to this PR and may be better addressed separately.

Thanks again, and looking forward to your feedback!

suzuki-2001, Jun 23 '25 17:06