llm-foundry
Enable HF SpeedMonitor
Enable SpeedMonitor on HF models by using PyTorch FlopCounterMode to calculate model FLOPs.
Oops, some of these changes are for our internal use. I'll remove them from this PR.
Hey @rlrs, thanks for the contribution! I didn't know about this PyTorch flop counter! We'll want to do a bit of testing to make sure that this reports the correct number and doesn't cause any issues with (1) speed (2) memory usage or (3) bad interactions with distributed training strategies like FSDP. What testing of this have you been able to do yourself?
Apologies for the lack of explanation or tests, I rushed this a bit.
So far I've used this with Mistral 7B and compared it against the standard Transformer Math 6PD calculation (6 × parameters × training tokens). The results are quite close, although my version relies on one of the same assumptions, namely that the backward pass costs 2x the forward pass. It is possible to wrap fwd+bwd in FlopCounterMode instead of just fwd. That seems more complicated to me, because that code would have to live outside the HF model wrapper, yet the wrapper is where the model FLOPs are returned from.
One uncertainty I have is how the FLOP counter interacts with non-PyTorch constructs like Flash Attention. I suspect that it might be necessary to register such code manually in order to get the correct result. If so, it might be silently underreporting FLOPs right now.
No worries @rlrs! If you're able to do some testing (and add some unit tests) that would be great! Otherwise we'll look into it when we get a chance and appreciate the suggestion!