feat: add AF-sample2 style random MSA masking (--mask_msa)
AF-sample2 MSA Masking
Congratulations on the Boltz-2 release — it’s a great milestone! Building on that success, this PR adds AF-sample2-style random MSA masking to help Boltz explore alternative conformations.
🚀 Why this PR matters
AF-sample2 demonstrated that randomly masking MSA columns (“AF-sample2 masking”) forces AlphaFold-2 to explore alternative conformations. On the OC23 benchmark it improved the non-preferred state in 9 / 23 targets (ΔTM > 0.05) while preserving the preferred state.
Boltz-1 is currently single-conformation; adding the same masking knob gives users a lightweight way to sample alternative complexes without touching the diffusion core.
The AFsample2 pipeline starts by generating MSAs for a given protein sequence. This is followed by randomized MSA masking (replacing a percentage of columns with X, denoting “unknown residue”) such that a unique MSA profile is fed into the system on every inference run.
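As a rough illustration of the column-masking idea described above, here is a minimal pure-Python sketch. The function name `mask_msa_columns` and the toy alignment are hypothetical and are not taken from AFsample2 or Boltz; the sketch assumes the first row is the query and leaves it unmasked, and it draws one mask per seed so that different runs typically see different MSA profiles.

```python
import random


def mask_msa_columns(msa, mask_rate, seed):
    """Return a copy of `msa` with random columns replaced by 'X'.

    msa: list of equal-length aligned sequences; msa[0] is assumed to be the query.
    mask_rate: probability that any given column is masked.
    seed: seed for the column selection, so a run can be reproduced.
    """
    rng = random.Random(seed)
    n_cols = len(msa[0])
    # Decide once per column, so the same columns are hidden in every row.
    masked_cols = {c for c in range(n_cols) if rng.random() < mask_rate}
    masked = [msa[0]]  # the query row is kept intact (assumption of this sketch)
    for row in msa[1:]:
        masked.append(
            "".join("X" if c in masked_cols else ch for c, ch in enumerate(row))
        )
    return masked


# Toy alignment, for illustration only.
toy_msa = ["MKTAYIAKQR", "MKSAYLAKQR", "MRTAYIGKQK", "MKTGYIAKHR"]
for run_seed in (0, 1, 2):
    print(run_seed, mask_msa_columns(toy_msa, mask_rate=0.3, seed=run_seed)[1:])
```

Using a different seed per inference run is what makes each run see its own masked alignment; reusing a seed reproduces the same mask.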
📊 Experiment Results
On the OC23 set (23 two-state targets), 9 / 23 non-preferred states improved by ΔTM > 0.05 while the preferred state stayed intact (evaluated with USalign).

Note on MSAs: All evaluations reuse the multiple-sequence alignments distributed with the original AF-sample2 repository.
What’s in this PR
- CLI flags
  - `--mask_msa` (bool, default False)
  - `--mask_rate_msa` (float 0–1, default 0.1)
  - `--mask_seed_msa` (int, default 42)
- MSA masking in MSAModule
  - `trunk.py`, `trunkv2.py`
  - After (optional) subsampling, build a mask, skip the query row, then set UNK (≡ 20) = 1 at the masked positions.
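The masking step described above can be sketched roughly as follows. This is not the code in `trunk.py` or `trunkv2.py`: it is a pure-Python stand-in that assumes the MSA is already one-hot encoded, that masking is column-wise as in AF-sample2, and that "set UNK = 1" means replacing the one-hot vector with one whose UNK channel is 1. `NUM_TOKENS` is a toy vocabulary size chosen only so that index 20 exists, and `mask_one_hot_msa` is a hypothetical name.

```python
import random

NUM_TOKENS = 21  # toy vocabulary size (assumption, not Boltz's real value)
UNK = 20         # unknown-residue index, as stated in the PR description


def one_hot(token, size=NUM_TOKENS):
    """Return a one-hot list with a 1 at position `token`."""
    vec = [0] * size
    vec[token] = 1
    return vec


def mask_one_hot_msa(msa_one_hot, mask_rate, seed):
    """Return a masked copy of a [n_seqs][n_cols][NUM_TOKENS] one-hot MSA."""
    rng = random.Random(seed)
    n_cols = len(msa_one_hot[0])
    # Build the mask: one decision per alignment column.
    col_mask = [rng.random() < mask_rate for _ in range(n_cols)]
    # Copy so the caller's data is not modified.
    out = [[list(vec) for vec in row] for row in msa_one_hot]
    for r in range(1, len(out)):  # start at 1 to skip the query row
        for c in range(n_cols):
            if col_mask[c]:
                out[r][c] = one_hot(UNK)  # only the UNK channel is set to 1
    return out


# Toy token indices for a 3-sequence, 4-column alignment.
tokens = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
encoded = [[one_hot(t) for t in row] for row in tokens]
masked = mask_one_hot_msa(encoded, mask_rate=0.5, seed=42)
print([vec.index(1) for vec in masked[1]])
```

Masked positions stay valid one-hot vectors, which keeps the downstream input shape unchanged.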
# usage
boltz predict [input_fasta_or_yaml] --mask_msa --mask_rate_msa 0.1 --mask_seed_msa 42
🛠️ Dev Env
| Item | Spec / Version |
|---|---|
| GPU | NVIDIA RTX 4080 |
| CPU | AMD Ryzen Threadripper PRO 7955WX (16 cores) |
| OS | Linux 6.11.0-19-generic (rtx4080 host) |
| Python | 3.10 (miniforge) |
| Install | pip install -e . + pip install tensorboard |
Feedback Welcome
We would greatly appreciate any feedback on the implementation details, design choices, or further ideas to improve this feature.
Thanks for this contribution! Is the confidence model able to pick out that a better structure was found or does this only work under oracle setting?
@jwohlwend That's the key question, thank you for asking!
You are correct, the confidence model alone isn't a reliable way to pick out the best structure from the samples. For now, selecting the best candidate requires external validation, as we did in our benchmark (the oracle setting).
The main strength of this feature is that it generates a diverse set of high-quality conformations, similar in effect to MSA subsampling. We see improving the automated selection process as an important next step.
Cool, I’m good to merge with some documentation in prediction.md that also highlights this point! Would you be open to writing that?
Great! I'll add the documentation and request your review once it's updated.
@jwohlwend Hi, I've updated prediction.md:
- Added the new arguments
- Included the note about the limitations of confidence scores
- Made a cosmetic tweak to the table's Markdown source
Ready for another review when you have a moment. Thanks!
Hi @jwohlwend, this PR is now updated and ready for review.
- The `--mask_msa` logic was adjusted to apply masking before one-hot encoding, ensuring consistency with how MSA masking operates at the raw token level.
- The masking is now only applied if `--mask_rate_msa > 0`, simplifying CLI usage.
- All conflicts have been resolved, and the feature was tested on version `v2.1.1`.
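To make the revised ordering concrete, here is a small pure-Python sketch of masking at the raw token level, gated on the rate, followed by one-hot encoding. The helper names (`maybe_mask_tokens`, `one_hot_encode`) and the toy vocabulary size are hypothetical and do not come from the Boltz code base; column-wise masking and an untouched query row are assumptions carried over from the PR description.

```python
import random

UNK = 20  # unknown-residue token index, as given in the PR description


def maybe_mask_tokens(msa_tokens, mask_rate, seed):
    """Replace random columns of integer tokens with UNK; no-op if mask_rate <= 0."""
    if mask_rate <= 0:
        return msa_tokens  # gate: masking is skipped entirely
    rng = random.Random(seed)
    n_cols = len(msa_tokens[0])
    cols = [rng.random() < mask_rate for _ in range(n_cols)]
    # Keep the query row (index 0) as-is, mask the remaining rows.
    return [list(msa_tokens[0])] + [
        [UNK if cols[c] else tok for c, tok in enumerate(row)]
        for row in msa_tokens[1:]
    ]


def one_hot_encode(msa_tokens, num_tokens):
    """One-hot encode a [n_seqs][n_cols] token matrix."""
    return [
        [[1 if k == tok else 0 for k in range(num_tokens)] for tok in row]
        for row in msa_tokens
    ]


# Toy tokens; 21 is an illustrative vocabulary size, not Boltz's real one.
tokens = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
masked_tokens = maybe_mask_tokens(tokens, mask_rate=0.5, seed=42)
features = one_hot_encode(masked_tokens, num_tokens=21)  # encoding happens after masking
print(masked_tokens)
```

Masking the integer tokens first means the encoder never needs to know a mask was applied, and a rate of 0 leaves the input untouched.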
When testing with examples, I encountered a KeyError: 'profile_affinity' in boltz2.py:622. Setting affinity=False allowed me to verify mask_msa works correctly. This issue appears unrelated to this PR and may be better addressed separately.
Thanks again, and looking forward to your feedback!