llm-foundry

Enable HF SpeedMonitor

Open · rlrs opened this issue 1 year ago · 4 comments

Enable Composer's SpeedMonitor on Hugging Face (HF) models by using PyTorch's FlopCounterMode to calculate model FLOPs.

rlrs · Feb 26 '24 18:02

Oops, some of these changes are for our internal use. I'll remove them from this PR.

rlrs · Feb 26 '24 18:02

Hey @rlrs, thanks for the contribution! I didn't know about this PyTorch flop counter! We'll want to do a bit of testing to make sure that this reports the correct number and doesn't cause any issues with (1) speed (2) memory usage or (3) bad interactions with distributed training strategies like FSDP. What testing of this have you been able to do yourself?

dakinggg · Feb 26 '24 23:02

Apologies for the lack of explanation or tests; I rushed this a bit.

So far I've used this with Mistral 7B, comparing against the standard Transformer Math 6PD calculation (6 × parameters × tokens), and the results are quite close. That said, I rely on one of the same assumptions, namely that the backward pass costs 2x the forward pass. It is possible to wrap fwd+bwd in FlopCounterMode instead of just fwd, but to me that seems more complicated, since that code would have to live outside the HF model wrapper, which is where the model FLOPs have to be returned from.
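The close agreement with 6PD is expected under that shared assumption: the rule of thumb counts about 2 FLOPs per parameter per token for the forward pass and about 4 for the backward pass, so "forward x 3" and 6·P·D coincide by construction. A small arithmetic sketch; the function names are hypothetical, and the parameter and token counts are illustrative round numbers, not Mistral 7B's exact figures.

```python
def six_pd_flops(num_params: int, num_tokens: int) -> int:
    """Transformer Math rule of thumb: ~6 FLOPs per parameter per token (fwd + bwd)."""
    return 6 * num_params * num_tokens


def forward_flops_estimate(num_params: int, num_tokens: int) -> int:
    """~2 FLOPs per parameter per token for the forward pass (one multiply, one add)."""
    return 2 * num_params * num_tokens


def train_flops_from_forward(fwd_flops: int, bwd_multiplier: int = 2) -> int:
    """Total training FLOPs if backward costs `bwd_multiplier` times the forward pass."""
    return fwd_flops * (1 + bwd_multiplier)


def relative_gap(measured: float, reference: float) -> float:
    """Relative difference between a measured count and a reference estimate."""
    return abs(measured - reference) / reference


if __name__ == "__main__":
    params = 7_000_000_000  # illustrative round number, not an exact parameter count
    tokens = 8 * 4096       # illustrative batch: 8 sequences of 4096 tokens
    fwd = forward_flops_estimate(params, tokens)
    print(train_flops_from_forward(fwd) == six_pd_flops(params, tokens))  # True by construction
```

Because both sides share the 2x backward multiplier, a close match mainly validates the forward count; counting fwd+bwd directly would be the way to test the multiplier itself.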

One uncertainty I have is how the FLOP counter interacts with non-PyTorch kernels like Flash Attention. I suspect that it might be necessary to register such code manually in order to get the correct result. If so, it might be silently underreporting FLOPs right now.
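One way to gauge how much could go missing: the attention-score matmuls (QK^T and the weighted sum over V) involve no parameters, so a parameter-based 6PD count does not cover them either. The sketch below estimates their forward cost, assuming full (non-causal) attention and ignoring softmax, masking and scaling; causal kernels skip roughly half of this work. The function names and shapes are hypothetical. If the counter does miss these ops, `FlopCounterMode` accepts a `custom_mapping` argument for supplying FLOP formulas for additional ops, though whether a given Flash Attention build is visible to the dispatcher at all would need to be checked.

```python
def attention_matmul_flops(batch: int, seq_len: int, d_model: int, n_layers: int) -> int:
    """Forward FLOPs of QK^T and (softmax) @ V across all layers, full attention.

    Per layer, QK^T costs 2 * B * S * S * d_model FLOPs (summed over heads),
    and the weighted sum over V costs the same, giving 4 * B * S^2 * d_model.
    Softmax, masking and scaling are ignored.
    """
    per_layer = 4 * batch * seq_len * seq_len * d_model
    return n_layers * per_layer


def parameter_forward_flops(num_params: int, num_tokens: int) -> int:
    """~2 FLOPs per parameter per token: the part a 6PD-style count covers."""
    return 2 * num_params * num_tokens


if __name__ == "__main__":
    # Illustrative shapes loosely in the range of a 7B decoder, not an exact config.
    b, s, d, layers = 1, 4096, 4096, 32
    attn = attention_matmul_flops(b, s, d, layers)
    params_fwd = parameter_forward_flops(7_000_000_000, b * s)
    print(f"attention matmuls: {attn:.3e} FLOPs, "
          f"{attn / (attn + params_fwd):.1%} of the estimated forward total")
```

Comparing the counter's output with Flash Attention enabled and disabled, against an estimate like this, would show whether those kernels are being counted.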

rlrs · Feb 27 '24 10:02

No worries @rlrs! If you're able to do some testing (and add some unit tests), that would be great! Otherwise, we'll look into it when we get a chance. We appreciate the suggestion!

dakinggg · Mar 01 '24 00:03