Add command-line chunking parameters for flexible VRAM utilization
Several issues here and on Slack relate to VRAM limitations: #71, #83, #106, #167, #169, #197, #214.
By tweaking the chunking parameters, you can significantly reduce the VRAM requirements. Using this strategy, I have modelled systems with 4000+ tokens on our H200s.
This PR:
- Provides chunking parameters as command-line arguments.
- Adds chunking to the `PairformerLayer` `transition_z` calculation (same as `MSAlayer`), which becomes a memory bottleneck at large system sizes; see the sketch after this list. Default behaviour should be consistent with the existing implementation.
- I also had to enable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, which may be worth documenting.
- Chunking only occurs above the `chunk_size_threshold`, which may also be worth exposing as a command-line argument.
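
For concreteness, here is a minimal sketch of the inference-time chunking described above (parameter names such as `chunk_size` and `chunk_size_threshold` follow this description; the actual signatures in the PR may differ):

```python
import os

# expandable_segments reduces allocator fragmentation for the large,
# variably sized pairwise tensors; it must be set before the first
# CUDA allocation, hence before importing/using torch on the GPU.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch
import torch.nn as nn


def chunked_transition(
    transition: nn.Module,   # pointwise transition MLP applied over the channel dim
    z: torch.Tensor,         # pairwise activations, shape (B, N, N, C)
    chunk_size: int = 512,
    chunk_size_threshold: int = 1024,
) -> torch.Tensor:
    """Apply `transition` to `z` in row chunks to cap peak VRAM.

    At or below `chunk_size_threshold` tokens the tensor is processed
    in a single call, matching the existing unchunked behaviour.
    """
    n_tokens = z.shape[1]
    if n_tokens <= chunk_size_threshold:
        return transition(z)

    out = torch.empty_like(z)
    for start in range(0, n_tokens, chunk_size):
        end = min(start + chunk_size, n_tokens)
        # The transition MLP's expanded hidden activations now scale
        # with chunk_size * N rather than N * N.
        out[:, start:end] = transition(z[:, start:end])
    return out
```

Invocation might then look like `boltz predict input.yaml --chunk-size 512 --chunk-size-threshold 1024` (flag names hypothetical).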
@tlitfin-unsw thanks for taking the initiative! We faced similar issues internally and made a very similar change ourselves. :)
I would love to see this integrated and released as I'm currently limited by the linked issues.
Another way to increase the maximum size of predicted structures with a given amount of VRAM is to use 16-bit floating point (bfloat16) on Nvidia GPUs, as described in #263. In limited tests this allowed predicting 40% more residues with no noticeable loss of accuracy compared to the default 32-bit floating point. Pull request #264 adds an option to the `boltz` command to use bfloat16. I would love to see improvements to boltz that address the structure size limits, which are much smaller than AlphaFold 3's and are, I think, the main limitation for research biologists, most of whom have consumer GPUs with limited memory (e.g. an Nvidia 4090 with 24 GB) rather than H200 GPUs with 141 GB.
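
For anyone who wants to try this before #264 lands, inference under bfloat16 can be approximated with PyTorch autocast (a minimal sketch; the actual option and plumbing in #264 may differ):

```python
import torch

def predict_bf16(model: torch.nn.Module, batch: dict) -> dict:
    """Run inference under bfloat16 autocast: matmuls and most activations
    use bf16 (roughly halving their memory), while weights stay float32."""
    assert torch.cuda.is_bf16_supported(), "bfloat16 requires Ampere or newer"
    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            return model(batch)
```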
Closing this stale PR. I implemented these changes and added several other inference-only memory optimizations in this fork.