Added a --use_cuda_bfloat16 option to the boltz predict command.
Fixes #263. This allows predicting structures with 40% more residues without running out of GPU memory. I tested on a half-dozen PDB structures and did not see any loss of accuracy.
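For reference, the flag wraps inference roughly along these lines; the helper name and wiring below are an illustration, not the exact code in this pull request:

import contextlib
import torch

def prediction_context(use_cuda_bfloat16: bool):
    # Return a context manager: bfloat16 autocast when requested and a
    # CUDA device is available, otherwise a no-op context.
    if use_cuda_bfloat16 and torch.cuda.is_available():
        # Eligible ops run in bfloat16, roughly halving activation memory,
        # which is what allows larger structures on the same GPU.
        return torch.autocast("cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()

# Usage: with prediction_context(True): coords = model(batch)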
There is a spelling error: "use_cuda_bloat16: bool = False".
Thanks for pointing out the typo. I have fixed it in this pull request. I had actually fixed it a week ago, but in a separate Boltz fork I made specifically for the version used in ChimeraX, where I did my testing, and I forgot to apply the fix to this fork.
Boltz-2 now runs with bfloat16 where possible. Without more comprehensive evaluation, I think it might make sense to keep Boltz-1 always in float32. I will try to test that and get back to you!
Where are the lines of code in Boltz 2 that allow autocast to work? A git grep shows that all calls to torch.autocast have enabled=False.
$ git rev-parse HEAD
744b4aecb6b5e847a25692ced07c328e7995ee33
$ git grep -h autocast | sed -e 's/^ *//' | sort | uniq -c
1 # Fixes an issue with unused parameters in autocast
1 # Issue with unused parameters in autocast
2 and torch.is_autocast_enabled()
1 fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
1 fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
2 torch.clear_autocast_cache()
1 with torch.amp.autocast("cuda", enabled=False):
19 with torch.autocast("cuda", enabled=False):
2 with torch.autocast(device_type="cuda", enabled=False):
3 with torch.cuda.amp.autocast(enabled=False):
1 with torch.no_grad(), torch.autocast("cuda", enabled=False):
In Boltz 2, the use of bfloat16 is set by this line in the predict() function in main.py:
precision=32 if model == "boltz1" else "bf16-mixed",
and the stderr output of the Boltz 2 boltz predict command includes the line:
Using bfloat16 Automatic Mixed Precision (AMP)
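The snippet below is a toy illustration, not Boltz code, of how a global bfloat16 autocast region like the one Lightning enables for "bf16-mixed" interacts with the locally disabled autocast blocks found by the grep above:

import torch

def forward_under_amp(x: torch.Tensor) -> torch.Tensor:
    # Outer region mimics what precision="bf16-mixed" turns on globally.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        y = x @ x.mT  # matmul runs in bfloat16 under autocast
        # Inner opt-out corresponds to the enabled=False calls in the repo:
        # this block stays in float32 even while AMP is active outside.
        with torch.autocast("cuda", enabled=False):
            y = torch.linalg.vector_norm(y.float(), dim=-1)
    return y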
On an Nvidia RTX 4090, this gives faster predictions than Boltz 1 on 8 predictions ranging in size from 100 to 1400 tokens, and about the same speed as Boltz 1 predictions with the use_cuda_bfloat16 option described in this issue. So the "bf16-mixed" precision used by Boltz 2 appears to have about the same effect as the --use_cuda_bfloat16 option proposed above. I wonder whether the use of bfloat16, together with the many lines that keep specific calculations in float32,
torch.amp.autocast("cuda", enabled=False)
has been tested sufficiently to know that it does not degrade the accuracy of results. That was my main reason for suggesting a --use_cuda_bfloat16 option: the user could decide whether the potential trade-off of accuracy versus speed and maximum structure size was worthwhile. If the developers are confident that bfloat16 is not degrading accuracy, then it would be fine not to provide an option for choosing precision=32 versus precision="bf16-mixed".
The current Boltz 2 "bf16-mixed" precision should be changed to precision=32 for CPU predictions and Mac GPU predictions (using MPS, Metal Performance Shaders). The "bf16-mixed" setting makes those predictions 3 to 8 times slower because only CUDA has efficient bfloat16 support in torch. That is why the --use_cuda_bfloat16 option proposed above was limited to predictions that use CUDA.
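A sketch of that suggested selection logic (the helper name is hypothetical; the real code sets precision inline in predict()):

import torch

def choose_precision(model: str):
    # Keep Boltz 1 at full precision; for Boltz 2, use mixed precision
    # only when a CUDA device will actually run the model, since CPU
    # and MPS lack efficient bfloat16 support in torch.
    if model == "boltz1":
        return 32
    return "bf16-mixed" if torch.cuda.is_available() else 32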
Hi,
I do not see any difference in performance between bfloat16 and float32 on my Mac using MPS. However, I modified torch.amp.autocast("cuda", enabled=False) to device-agnostic code in my fork: torch.autocast(device_type, enabled=False)
A prediction for a 300-residue protein and a ligand runs in about 1 minute on my M1 Max with 64 GB.
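For illustration, that device-agnostic change could be wrapped like this (the helper is a sketch, and autocast support for "mps" depends on the installed torch version):

import torch

def no_autocast(device_type: str):
    # Disable autocast for whichever backend is in use ("cuda", "mps",
    # or "cpu") instead of hard-coding "cuda".
    return torch.autocast(device_type=device_type, enabled=False)

# Usage: with no_autocast(x.device.type): ...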
Again, I don't think it makes sense to enable this without extensive testing. Please use Boltz-2, which already uses bf16 where possible. Closing for now!
Makes sense to close this. As noted in the comment above (https://github.com/jwohlwend/boltz/pull/264#issuecomment-2960942079), Boltz 2 uses bfloat16 for some operations because it sets precision="bf16-mixed", and this seems to give speeds comparable to Boltz 1 with the proposed bfloat16 option under CUDA. However, the current Boltz 2 "bf16-mixed" precision leads to computation speeds slower by a factor of 3 or more than precision=32 for CPU runs without CUDA.
Hmm interesting about the CPU.. Maybe we should default to 32 on CPU?
I think precision should be "bf16-mixed" only when CUDA is used, and 32 in all other cases. In my tests on Intel CPUs and on Mac ARM GPU and CPU, it always slowed the calculations (dramatically on Intel Core CPUs, modestly on Mac) and did not reduce memory use significantly. I suspect that is because Intel and Mac ARM hardware lacks bfloat16 support in torch, so extra time is spent converting to float32 anywhere bfloat16 would otherwise be used.
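One way to spot-check that hardware-support hypothesis is a small matmul benchmark comparing float32 to bfloat16 on the current device; this timing harness is my own, not from the thread:

import time
import torch

def bench(dtype, device="cpu", n=1024, reps=10):
    a = torch.randn(n, n, device=device, dtype=dtype)
    (a @ a).sum().item()  # warm-up so one-time setup is not timed
    t0 = time.perf_counter()
    for _ in range(reps):
        (a @ a).sum().item()  # .item() forces synchronization
    return (time.perf_counter() - t0) / reps

print("float32 :", bench(torch.float32))
print("bfloat16:", bench(torch.bfloat16))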