
Adjust MFU to account for FP8

Open lessw2020 opened this issue 1 year ago • 24 comments

From internal discussions, logging an issue around updating our MFU calculations so that if FP8 is used, we can generate an accurate MFU number.

Atm - FP8 replaces wq/wk/wv/wo in Attention, and w1/w2/w3 in the MLP.

Thus, need an adjusted calculation.

In addition, would like to correctly pull the proper MFU (fp8 or bf16) based on the training config being run so this is handled automatically for the user.

lessw2020 avatar Aug 23 '24 22:08 lessw2020

yes, this needs to be updated. The MFU computations for fp8 are too good to be true :) CC: @lchu-ibm

raghukiran1224 avatar Sep 08 '24 11:09 raghukiran1224

Hmm I don't see a proper way to calculate MFU when some computations (linear) use FP8 and some others use BF16 (everything else, e.g. sdpa). What should be the right way?

tianyu-l avatar Nov 22 '24 00:11 tianyu-l

I chatted with @stas00 briefly offline and he suggested 2 approaches

  1. Assume the peak is what you'd get if all computations were done in FP8 - this is no different from how MFU computations are done for BF16-only runs
  2. Assume the peak is a weighted average of the BF16 and FP8 peaks

Option 1 feels simpler; the main con is that the numbers we share might look "low" compared to other vendors if they use option 2, but we can make it clear how we define MFU in high-level marketing material. A rough sketch of both options is below.
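For illustration, a minimal sketch of the two definitions; the peak numbers below are roughly H100 dense BF16/FP8 peaks, and the achieved flops and FP8 fraction are made-up placeholders, not anything measured in torchtitan:

```python
# Sketch of the two MFU definitions for mixed BF16/FP8 training.
# All numbers below are illustrative, not measured values.
achieved_flops_per_sec = 600e12  # model flops per step / step time, however you count them
peak_bf16 = 989e12               # ~H100 dense BF16 peak
peak_fp8 = 1979e12               # ~H100 dense FP8 peak
fp8_fraction = 0.7               # fraction of model flops executed in FP8

# Option 1: pretend everything could run at the FP8 peak (simple, conservative).
mfu_option_1 = achieved_flops_per_sec / peak_fp8

# Option 2: weight the peak by how much of the compute runs in each dtype.
weighted_peak = fp8_fraction * peak_fp8 + (1 - fp8_fraction) * peak_bf16
mfu_option_2 = achieved_flops_per_sec / weighted_peak

print(f"option 1: {mfu_option_1:.1%}, option 2: {mfu_option_2:.1%}")
```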

msaroufim avatar Dec 04 '24 17:12 msaroufim

My personal thoughts:

I was surprised when I saw that torchtitan uses the simple and overoptimistic "academic" flops formula (https://github.com/pytorch/torchtitan/blob/main/torchtitan/utils.py#L231) considering that torch.utils.flop_counter.FlopCounterMode already exists (and in my experience, works quite well).
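For reference, a minimal sketch of how FlopCounterMode can be used to count a step's actual flops per operator (the model and input here are toy placeholders, not torchtitan code):

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

# Toy stand-ins for the real model and training batch.
model = torch.nn.Linear(1024, 1024)
x = torch.randn(8, 1024)

flop_counter = FlopCounterMode(display=False)
with flop_counter:
    model(x).sum().backward()  # flops are tallied per op, forward and backward

total_flops = flop_counter.get_total_flops()
print(f"{total_flops / 1e9:.2f} GFLOPs for this step")
```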

I feel like PyTorch could offer a very similar utility that measures peak flops per operator. Imagine that you prepare your model first (parallelize, fp8, etc.) and then measure the peak flops. From my understanding, such a utility could account for the differences in the theoretical performance of the operators based on runtime information like their dtypes.

But if you want to go the easy route, option 1 seems the better choice to me. For big models most flops will go to the linears anyway, and forkers can always customize the calculation.

But please do not perpetuate the abuse of MFU as a marketing number 🙏. You can report tokens per second (if not padding) or sequences per second. MFU measuring contests are pointless unless the flops calculation is proved to be exactly the same.

carmocca avatar Dec 05 '24 01:12 carmocca

thanks @carmocca

  1. To me there is no right way of doing mixed MFU with BF16 and FP8, the approaches @msaroufim mentioned sound "less wrong" but not correct.
  2. Fine-grained MFU e.g. at op level makes sense, but at the cost that it is no longer a top-line metric for e2e systems.
  3. I agree that we should focus on metrics like tokens/sec to make fair comparisons (but even for that different repos have different ways of measuring). This is what we report in the torchtitan paper.
  4. We still keep MFU around because it's more intuitive than tokens/sec. E.g. when using Context Parallel with ultra long sequences, the throughput would drop a lot naturally, but we can look at MFU ~30%--40% to have confidence that the implementation is good.
  5. I guess the way torchtitan calculates MFU was unfortunately influenced by other popular repos. I hope the community can agree on the metrics and their calculation.

tianyu-l avatar Dec 05 '24 21:12 tianyu-l

added TFLOPs in https://github.com/pytorch/torchtitan/pull/847 as an alternative -- closing this issue

tianyu-l avatar Feb 26 '25 06:02 tianyu-l

Will this ever be revisited? Currently seeing some bogus MFU numbers like 87% when using fp8.

chelsea0x3b avatar Nov 12 '25 17:11 chelsea0x3b

MFU when using mixed fp8 and bf16 is not well-defined. People use tokens/sec or tflops instead.

tianyu-l avatar Nov 12 '25 21:11 tianyu-l

Is torchtitan the place where a new standard could be born? Even just having separate MFU numbers reported for different dtypes could be more clear

chelsea0x3b avatar Nov 12 '25 21:11 chelsea0x3b

Perhaps going forward it'd be easiest to report tflops not as a single number, but as something like 415 (bf16), 490 (fp8) - then fp4, mxfp4, etc. can be added as well - whatever dtypes are used, report each separately. Then everything becomes well-defined, and MFU as well: 43% (bf16), 45% (fp8).
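A per-dtype breakdown like that is easy to produce once flops are attributed per dtype; a minimal sketch, with made-up achieved numbers and illustrative device peaks:

```python
# Hypothetical per-dtype report: achieved TFLOPS attributed to each dtype,
# divided by that dtype's device peak. All values are placeholders.
achieved_tflops = {"bf16": 425.0, "fp8": 890.0}
peak_tflops = {"bf16": 989.0, "fp8": 1979.0}

for dtype, achieved in achieved_tflops.items():
    print(f"{dtype}: {achieved:.0f} TFLOPS ({achieved / peak_tflops[dtype]:.0%} MFU)")
```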

stas00 avatar Nov 12 '25 21:11 stas00

@coreylowman @stas00

My take is that MFU was used as a "top-line" "logical" metric, whose denominator takes into account the time we spend on data loading, recomputation, etc. If we separate computation per dtype / per op, yes, we could count a per-op / per-dtype utilization rate, but how do you attribute the "general waste"?

So I'd say

  • If the current MFU is confusing and useless, we should consider deleting it. But then we lose it for bf16 monitoring; should we enable it conditionally? And what if I just wanna get a sense of how fast fp8 is, rather than taking the MFU literally? One could argue that we should use tflops / tps, which is what I mentioned earlier and what I'm seeing the community / papers / tech reports using.
  • If users want fine-grained metrics measuring something else, proposals and discussions are welcome. But to me it sounds like adding an entirely new metric to track (e.g. just like grad_norm), rather than a replacement of MFU.

tianyu-l avatar Nov 12 '25 21:11 tianyu-l

I personally don't use MFU as it's already a BS number since 100% is unachievable, and moreover the achievable efficiency varies wildly between GPUs. So if you move from B200 to GB200 you already have to recalibrate, since the achievable performance/efficiency isn't the same (see the tables I linked to).

Relative TFLOPS are very useful because they tell you how much improvement or regression has been made against some previous baseline. But when you run different parts of the compute in different regimes, a single TFLOPS number is not possible. Even if you somehow do a weighted average, the important information will be lost in the averaging. So, as I mentioned above, reporting TFLOPS per dtype (for all ops of that dtype) is probably the next common-sense practical thing.

If people then want to convert this to MFU so that they could boast about their system's efficiency it's totally their choice.

I did like MFU as an indicator of the efficiency of the compute before I created the MAMF metrics. So for example Megatron-LM reported that when they moved from A100 to H100 back in the day they couldn't reach the same MFU - not because their code was somehow inferior (it was super optimized), but because the H100's achievable compute efficiency was much worse than that of the A100 (and the trend continues btw if you check the table in the link of the first para). So that's where I'd say it's still useful. But again, not as a comparison across frameworks. The latter is wrong because people can't even agree on how to measure FLOPs (in particular various fwd re-runs, and for example Flash Attention flops don't fit into any known approach - Tri Dao himself shares some very unusual multipliers because he wrote this stuff - what are the chances that other developers will calculate FLOPs correctly?).

My recommendation is this - measure FLOPs in whatever way you like and compare your previous code version's FLOPs/s to your current code version's. Then you know whether you've made a regression or improved things. If it's a public framework, clearly disclose how you measure FLOPs for different components.

stas00 avatar Nov 12 '25 22:11 stas00

I think you could have a total time taken metric, and then do metrics per module/dtype. Then overhead/general waste is computed as total time - sum(all module times). But I don't think you could really get waste per dtype, which you'd need if you wanted a good MFU per dtype. 🤔 I think in the meantime, just disabling MFU if float8 converters are active could be a good temporary bandaid. I'll keep thinking about this. It's something the field definitely needs to figure out, especially with fp4 also coming into the mix.

chelsea0x3b avatar Nov 12 '25 22:11 chelsea0x3b

(copying from X)

dtypes are a numerical choice on the frontend; MFU measures the implementation of the backend. It should be well-defined as sum_dtype(flops achieved) / sum_dtype(theoretical flops), right?

50TF bf16 for half the model and 50TF fp8 for the other half @ 100TF, 200TF respectively would be 33% MFU
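Spelling that example out as a sketch (achieved and peak flops per dtype, summed before dividing):

```python
# bwasti's example: sum of achieved flops over sum of theoretical peaks.
achieved = {"bf16": 50e12, "fp8": 50e12}  # TFLOPS actually delivered per dtype
peak = {"bf16": 100e12, "fp8": 200e12}    # theoretical peak per dtype

mfu = sum(achieved.values()) / sum(peak.values())
print(f"MFU = {mfu:.0%}")  # (50 + 50) / (100 + 200) ≈ 33%
```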

bwasti avatar Nov 12 '25 22:11 bwasti

yup, but as I suggested above it's probably best not to average, since an important signal will be lost if the implementation for one dtype is worse than for the other. Or perhaps report the weighted average plus the breakdown; then the user has a choice on the level of detail they want.

stas00 avatar Nov 12 '25 22:11 stas00

I'd assume that if a user really doesn't know how much of their compute runs in each dtype, they probably don't care about MFU (or know what to do with that number). That being said, it would be useful to report a breakdown for users that do.

calling it a weighted average is a bit deceptive since it's an actual count of the number of flops achieved divided by the theoretical peak that could be achieved. (weighted average is the solution to this, but not the intuition)

bwasti avatar Nov 12 '25 22:11 bwasti

Yes, that's why I'm suggesting a breakdown report.

Users should care a lot about reported TFLOPs/s and try to improve those. If they don't it will cost them $$ and lost time.

stas00 avatar Nov 12 '25 22:11 stas00

Sorry ya'll didn't realize email replies were formatted that badly lol.

  • +1 on breakdown report by dtype
  • +1 on MFU being a bad metric and not even usable across different GPU types
  • +1 on tflops being better due to not using the magical peak flops number
  • +1 on tflops still not being clear enough due to vague-ness of model flops

IDK I kind of just prefer tokens/s/device to everything else.

BUT regardless of opinions on the above, I don't think the inaccurate MFU number should be shown if we know it's inaccurate, so I still stand by temporarily disabling showing MFU if fp8 is used.

chelsea0x3b avatar Nov 12 '25 23:11 chelsea0x3b

tokens/s is also a very vague metric other than for local relative comparisons, and even then one has to be very careful - this number alone is meaningless. Due to the quadratic nature of attention, the more tokens you have in a contiguous sample, the more "efficient" the system becomes, because it moves towards being compute-bound. So when you report tokens/s - did you use 1k-token or 1M-token-long samples? With the same system you'd get very different results depending on the sample size. (And take packed samples vs full samples into consideration, since self-attention will be very different compute-wise.)

Have a look at https://arxiv.org/abs/2506.13996 - with 15M-token-long samples - all your overheads are irrelevant since the model will take a very very long time to compute 15M token attention, but the tflops as you can see from reports are very high!

stas00 avatar Nov 12 '25 23:11 stas00

I don’t see what’s so wrong with the weighted average approach - this seems like the most intuitive option to me.

Is the "information loss" you're concerned about with this approach just the more granular per-dtype information that could help highlight performance issues with an fp8 implementation, for example?

If so, this is an important but secondary concern IMO. As a first step, I think we could use the weighted average to get a more accurate MFU metric for fp8 training with mixed precision. When this MFU metric is low, the user can do some profiling and examine traces to determine where the issue is.

danielvegamyhre avatar Nov 13 '25 03:11 danielvegamyhre

@bwasti @danielvegamyhre

Please educate me about the GPU basics here

50TF bf16 for half the model and 50TF fp8 for the other half @ 100TF, 200TF respectively would be 33% MFU

Would one be carrying out bf16 computation and fp8 computation "at the same time"?

IIUC the denominator assumes peak flops is achieved when both bf16 and fp8 cores are saturated, so a model training that happens purely in bf16 and can hypothetically achieve 100% MFU (in the traditional sense) would, with this new definition, have its averaged MFU' capped at 100 / 300 = 33%. Am I getting it right?

Or do you mean averaging the peak flops in the denominator as well, like (100 + 200) / 2? Should this also be a weighted average, and if so, how should we weight each dtype?

tianyu-l avatar Nov 13 '25 03:11 tianyu-l

Assuming only one dtype can run at a time, MFU can be defined as: MFU = 100 * sum_i(model_flops[i] / peak_flops[i]) / time, where i iterates over data types, peak_flops[i] is the peak flops per second for the i-th data type, model_flops[i] is the corresponding data type's flops in the model, and time is the total time the model took to run. Conceptually, the numerator is the ideal time it would take at 100% utilization and the denominator is the actual time it took to compute. It gets more complex when different data types can run in parallel on the HW, with max() instead of sum(). MFU is still a useful metric, even if 100% is not achievable in practice.
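A minimal sketch of that definition (all names and numbers are illustrative; model_flops would come from a flop counter and peak_flops from the device specs):

```python
def mixed_dtype_mfu(model_flops, peak_flops, step_time_s):
    """MFU (%) when the different dtypes run one at a time on the same units.

    model_flops[d]: flops executed in dtype d during the step
    peak_flops[d]:  the device's peak flops/sec for dtype d
    step_time_s:    measured wall-clock time of the step
    """
    # Ideal time at 100% utilization, summed over dtypes, vs. actual time.
    ideal_time_s = sum(model_flops[d] / peak_flops[d] for d in model_flops)
    return 100.0 * ideal_time_s / step_time_s

# Hypothetical step: 600 TF of bf16 work and 900 TF of fp8 work in a 2-second step,
# on a device with 1000 TF/s bf16 and 2000 TF/s fp8 peaks -> 52.5% MFU.
print(mixed_dtype_mfu(
    model_flops={"bf16": 600e12, "fp8": 900e12},
    peak_flops={"bf16": 1000e12, "fp8": 2000e12},
    step_time_s=2.0,
))
```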

v-arbatov avatar Nov 13 '25 05:11 v-arbatov

@v-arbatov this sounds like it unifies with the traditional bf16-only MFU.

However, adding percentages together loses the signal of comparison between bf16 and fp8. E.g. how do you compare [50% pure bf16] vs. [40% pure fp8] vs. [45% mixed]? It doesn't sound better than a per-dtype MFU (aka without the summation).

But yeah, according to what people say, let's

  1. in the short term, remove MFU when fp8 is enabled
  2. in the longer term, try to report per-dtype "MFU", but that needs some code refactoring

cc @fegin @wwwjn @shuhuayu

tianyu-l avatar Nov 14 '25 07:11 tianyu-l

@tianyu-l , it doesn't add percentages together; it computes one single utilization percentage against the ideal 100% utilization across all units, without measuring the latency of individual data types. Utilization: [50% pure bf16] > [45% mixed] > [40% pure fp8]. Please note that the utilization highlights only the efficiency of the whole model implementation. [45% mixed] may still run much faster than [50% pure bf16]. Also, say, in the model, the bf16 computations have 90% utilization and the fp8 parts only 50%. If the fp8 parts of the model are sufficiently large, they will drag the overall MFU down, even when the model runs faster. MFU is just one of the efficiency metrics and not perfect, as others noted.

v-arbatov avatar Nov 14 '25 08:11 v-arbatov