torchtitan
exclude embedding in MFU computation
Stack from ghstack (oldest at bottom):
- -> #280
Per suggestion in #274:
This PR removes the embedding from the number-of-parameters calculation, because the embedding op doesn't do a matmul.
This PR keeps the factor of 12 in the self-attention FLOP count, because accounting for sparsity (causal attention) is very tricky; for now let's follow the convention (PaLM paper, nanoGPT, etc.).
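For concreteness, here is a minimal sketch of the nanoGPT-style counting described above, with embedding parameters excluded; the function signature and names are illustrative, not torchtitan's actual API:

```python
# Illustrative sketch (not torchtitan's actual code): nanoGPT-style FLOPs-per-token
# counting with embedding parameters excluded and the conventional factor of 12
# for self-attention (no causal-sparsity discount).

def num_flops_per_token(num_params: int, num_embedding_params: int,
                        n_layers: int, n_heads: int, head_dim: int,
                        seq_len: int) -> int:
    # 6 = 2 FLOPs per multiply-add x 3 (1x forward + 2x backward), applied to
    # every weight that participates in a matmul; embedding lookups do not.
    matmul_params = num_params - num_embedding_params
    dense_flops = 6 * matmul_params
    # 12 = 2 (multiply-add) x 2 (QK^T and PV matmuls) x 3 (forward + backward),
    # counted over the full seq_len x seq_len score matrix (no causal discount).
    attn_flops = 12 * n_layers * n_heads * head_dim * seq_len
    return dense_flops + attn_flops
```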
FlashAttention makes use of the causal mask to do half the work, so one of my friends got >100% MFU when using the 12 factor rather than 7. Common options for multipliers in the causal setting are:
- 6: This is theoretically the lowest FLOP count possible, but it's not what an efficient implementation would use. I only know of one implementation which calculates it this way.
- 7: This is Flash Attention, and how they benchmark themselves. This is the most accurate counter to engineer against.
- 12: This is common, but not very useful in my experience, as it's easy to make the MFU go way up by changing context length.
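As a rough back-of-the-envelope illustration of where these three multipliers come from (my own accounting of the conventions discussed in this thread, per attention head, in units of seq_len² · head_dim FLOPs):

```python
# Back-of-the-envelope accounting of the 6 / 7 / 12 attention multipliers,
# per head, in units of (seq_len^2 * head_dim) FLOPs. Assumes the usual
# conventions: backward of a matmul costs ~2x its forward, and FlashAttention's
# backward recomputes the score matrix, so its backward is ~2.5x the forward.

fwd_full = 2 * 2              # two matmuls (QK^T and PV), 2 FLOPs per multiply-add -> 4
factor_12 = 3 * fwd_full      # forward + 2x backward, no causal discount           -> 12

fwd_causal = fwd_full / 2     # causal mask: only half the score matrix is needed   -> 2
factor_6 = 3 * fwd_causal     # theoretical minimum for causal forward + backward   -> 6
factor_7 = 3.5 * fwd_causal   # FlashAttention's benchmark convention (recompute)   -> 7
```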
@ad8e imo, I would either use 6 or 12. MFU was originally intended to exclude recomputation flops (from activation checkpointing), so it seems somewhat strange to me to re-include them here. In addition, other FlashAttention implementations (say, Triton's) actually end up with a factor of 9 for FLOPs.
My argument for 12 would be that, if you use 6, then you need to start being quite consistent about taking sparsity into account (for example, say we add sliding window attention). Otherwise, you end up back in the same situation you're referring to, where increasing sequence length results in overly increased MFU.
Perhaps non-obviously, FLOP counting is a somewhat subjective enterprise. For FlashAttention v2 itself, it makes sense to benchmark with 7, as what it cares about is "how much room is there to optimize this kernel".
I think the question is whether you care more about being "invariant" across sequence lengths or across attention patterns. I would probably agree that sequence length is the more important factor, so you've convinced me it should be 6 :)
@ad8e
I checked some other mainstream repos on how MFU is computed. From what I can tell, most (if not all) of them are using 12. For example:
- nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py#L296
- Megatron: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/training/training.py#L79
- MosaicML: https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/benchmarking/collect_results.py#L162
Since MFU is ultimately a metric derived from tokens-per-second (formula) and there is no consensus on which formula to use, we feel it's safer to follow the industry convention of 12, unless the community changes it together. One can always use tokens-per-second for more direct comparisons.
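For reference, a minimal sketch of that derivation, assuming the FLOPs-per-token convention above; the 312 TFLOPS default is an assumption for BF16 on an A100, not a torchtitan setting:

```python
# Illustrative sketch (not torchtitan's actual code): MFU as a metric derived
# from tokens-per-second. The default peak assumes BF16 on an A100 (~312 TFLOPS);
# substitute the peak throughput of your own hardware.

def mfu(tokens_per_second: float, flops_per_token: float,
        peak_flops: float = 312e12) -> float:
    """Per-device MFU: achieved FLOPs/s divided by the hardware peak."""
    return tokens_per_second * flops_per_token / peak_flops

# Rough example: a 7B-parameter model at ~6 * 7e9 = 4.2e10 FLOPs/token,
# processing 5,000 tokens/s on one GPU:
#   mfu(5_000, 4.2e10) ~= 2.1e14 / 3.12e14 ~= 0.67
```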
> I checked some other mainstream repos on how MFU is computed. From what I can tell, most (if not all) of them are using 12. For example:
If you wish to base your decision on what is most widely used, I agree that 12 is the most common number. Only Flash Attention (and I) use 7, and only one codebase I've seen uses 6.