fp32 vs bf16 training performance?
I trained OpenFold in both fp32 and bf16, and bf16 seems to perform better than fp32. Is that expected, or is my fp32 config not set up properly?
The command for bf16 is (in multi-GPU settings, the seed must be specified):
python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ output_dir/ \
    2021-10-10 \
    --template_release_dates_cache_path mmcif_cache.json \
    --precision bf16 \
    --gpus 8 --replace_sampler_ddp=True \
    --seed 4242022 \
    --deepspeed_config_path deepspeed_config.json \
    --checkpoint_every_epoch \
    --resume_from_ckpt ckpt_dir/ \
    --train_chain_data_cache_path chain_data_cache.json \
    --obsolete_pdbs_file_path obsolete.dat
The command for fp32 is identical except for the precision flag:
python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ output_dir/ \
    2021-10-10 \
    --template_release_dates_cache_path mmcif_cache.json \
    --precision 32 \
    --gpus 8 --replace_sampler_ddp=True \
    --seed 4242022 \
    --deepspeed_config_path deepspeed_config.json \
    --checkpoint_every_epoch \
    --resume_from_ckpt ckpt_dir/ \
    --train_chain_data_cache_path chain_data_cache.json \
    --obsolete_pdbs_file_path obsolete.dat
For the bf16 run, the bfloat16 flag in deepspeed_config.json is also enabled (it is left disabled for fp32).
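For reference, this is a minimal sketch of the relevant section of deepspeed_config.json, based on the bfloat16["enabled"] flag mentioned later in the thread; the rest of the config (optimizer, ZeRO settings, etc.) is omitted here and may differ in the actual file:

{
  "bfloat16": {
    "enabled": true
  }
}

For the fp32 run, the same block has "enabled": false.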

OF training seems to have two different modes, which you've observed here. I don't think this has anything to do with precision, because other people using OF have gotten stuck in the lower mode while using bfloat16 and successfully converged with fp32. I suspect that if you fiddle with the seed you'll be able to induce the higher mode in the fp32 model too.
For what it's worth, runs that get stuck in the lower mode eventually undergo what I can only describe as a "phase transition" and rocket up to the accuracy of the normal runs. When that phase transition occurs seems to vary dramatically, as one might expect, and I've seen it happen anywhere from 7k global steps to about 20k (using a batch size of 120). It does occur very consistently, however: I've only seen a single run get permanently stuck on the lower lDDT-Ca plateau.
What you've shared here is good data. I've only previously seen this bimodal phenomenon in runs that also use the self-distillation set. In that case, the two modes seem to be at ~0.81 and ~0.35, rather than the ~0.7 and ~0.4 you're getting here. Granted, that discrepancy might also have to do with the fact that you're using a very small effective batch size of 8 (in the standard AF training setup it's ~120).
Unfortunately, I've never seen the lower mode in my personal training setup, so I don't think it's inevitable, but I have yet to isolate exactly what's causing it. If anyone has the insight/compute to figure out what's happening here, that would be very much appreciated.
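As a rough sketch of how one might probe that seed sensitivity, the fp32 command above could be re-launched with a few different seeds; the seed values and per-seed output directories below are arbitrary examples, and --resume_from_ckpt is dropped so each run starts from scratch:

# Re-launch the fp32 run with several seeds to see whether the higher mode can be induced.
for SEED in 42 1234 4242022; do
    python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ output_dir_seed_${SEED}/ \
        2021-10-10 \
        --template_release_dates_cache_path mmcif_cache.json \
        --precision 32 \
        --gpus 8 --replace_sampler_ddp=True \
        --seed ${SEED} \
        --deepspeed_config_path deepspeed_config.json \
        --checkpoint_every_epoch \
        --train_chain_data_cache_path chain_data_cache.json \
        --obsolete_pdbs_file_path obsolete.dat
done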
Thanks for your reply. I will try running the fp32 model again and check the result. The commit ID for my previous training is 3fb643e92d6f80a468a8fd2a41a8808c369074c9. The random seed for the fp32 and bfloat16 runs is also the same; the only difference is that the bfloat16["enabled"] flag is set to true for bf16 and false for fp32.