MFU drops significantly when using megablox with more experts

Open rodrigo-f-nogueira opened this issue 11 months ago • 4 comments

I'm testing Mixtral-8x7B without attention so I can isolate the effects of the MoE layer.

When num_experts=8 and num_experts_per_token=2, MFU on a v5p-64 is 50.4%, which is good.

However, I wanted to test an architecture that is more similar to DeepSeek's, which uses more experts.

So I increased the number of experts from 8 to 56 (7x increase) and the number of experts per token from 2 to 14 (7x increase), and decreased moe_intermediate_size from 14336 to 2048 (7x decrease). This keeps the total and active parameter counts the same as Mixtral's.
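For anyone checking the arithmetic, here is a quick sketch of the parameter count per MoE layer. It assumes a gated MLP with three weight matrices per expert and Mixtral-8x7B's hidden size of 4096; the helper name is made up for illustration, and the router weights (which do grow with the number of experts, but are tiny) are ignored.

```python
# Assumption: each expert is a gated MLP with three matrices
# (gate, up, down), each of size HIDDEN x intermediate.
HIDDEN = 4096  # Mixtral-8x7B hidden size


def moe_params(num_experts, experts_per_token, intermediate):
    """Return (total, active) expert parameters for one MoE layer."""
    per_expert = 3 * HIDDEN * intermediate
    total = num_experts * per_expert
    active = experts_per_token * per_expert
    return total, active


mixtral = moe_params(num_experts=8, experts_per_token=2, intermediate=14336)
fine_grained = moe_params(num_experts=56, experts_per_token=14, intermediate=2048)

print("mixtral      (total, active):", mixtral)
print("fine-grained (total, active):", fine_grained)
```

Both configs give the same totals because 8 x 14336 = 56 x 2048 and 2 x 14336 = 14 x 2048, so the matmul FLOPs per token are unchanged and any MFU difference has to come from elsewhere.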

The problem is that in this new architecture with more experts, the MFU drops to 28%!

(BTW, I tried multiple configs with different tile_sizes, TPU sizes and batch sizes, all leading to 25-28% MFU)

Any help is much appreciated.

(cc'ing @sharadmv @RissyRan @lenscloth who might be interested in this problem)

rodrigo-f-nogueira avatar Feb 09 '25 11:02 rodrigo-f-nogueira

Thanks for reaching out! It looks like you have already tuned the general tile size a bit (here), but the best value can differ a lot depending on TPU type, topology (size), and model config. For a fixed model config, you could write a script that searches for the best tile_sizes. With the FSDP sharding strategy (the default), larger batch sizes should also help performance. As a next step, you could turn on this profiling option and see which operation (or extra communication) slows down the run. See more details here about JAX profiling.
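A tile-size search script could look roughly like the sketch below. This is only an outline under assumptions: `run_step` is a hypothetical callable that you would write to run one training step (or one MoE forward/backward) with the given tile size, the tile is assumed to be an (m, k, n) triple, and the candidate values are illustrative rather than recommended.

```python
import itertools
import time


def candidate_tiles(m_sizes=(256, 512, 1024),
                    k_sizes=(512, 1024),
                    n_sizes=(512, 1024)):
    """Enumerate (m, k, n) tile candidates. Values are illustrative only."""
    return list(itertools.product(m_sizes, k_sizes, n_sizes))


def sweep_tile_sizes(run_step, candidates, warmup=1, iters=3,
                     clock=time.perf_counter):
    """Time run_step(tile) for each candidate and return (best, timings).

    run_step is hypothetical: it should execute one step with the given
    tile size and block until the result is ready (in JAX, for example,
    by calling .block_until_ready() on an output array).
    """
    timings = {}
    for tile in candidates:
        for _ in range(warmup):
            run_step(tile)  # warmup calls absorb compilation time
        start = clock()
        for _ in range(iters):
            run_step(tile)
        timings[tile] = (clock() - start) / iters
    best = min(timings, key=timings.get)
    return best, timings
```

Usage would be something like `best, timings = sweep_tile_sizes(my_step, candidate_tiles())`, repeated for each TPU type and topology, since the best tile can differ between them.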

RissyRan avatar Feb 10 '25 16:02 RissyRan

Hi @RissyRan, sorry for taking so long...

jnp.take is the slowest operation when using 56 experts (with 14 active per token and mlp_dim=2048): https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/layers/linears.py#L404
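A rough sketch of why the gather gets more expensive with more active experts: in a sort-based dispatch, every token is copied once per selected expert, so the gathered array has tokens x experts_per_token rows. The code below uses NumPy as a stand-in for jnp so it runs anywhere; it is an illustration with made-up names and small sizes, not a copy of the MaxText code.

```python
import numpy as np


def dispatch(tokens, selected_experts, num_experts):
    """Group token copies by expert.

    tokens: [T, D]; selected_experts: [T, k] expert ids per token.
    Returns the gathered rows [T*k, D], rows per expert, and source indices.
    """
    k = selected_experts.shape[1]
    flat = selected_experts.ravel()                  # [T*k] expert id per slot
    order = np.argsort(flat, kind="stable")          # slots grouped by expert
    token_idx = order // k                           # source token of each slot
    gathered = np.take(tokens, token_idx, axis=0)    # [T*k, D]
    group_sizes = np.bincount(flat, minlength=num_experts)
    return gathered, group_sizes, token_idx


def random_routing(rng, num_tokens, num_experts, k):
    """Pick k distinct experts per token at random (stand-in for a router)."""
    return np.argsort(rng.random((num_tokens, num_experts)), axis=1)[:, :k]


rng = np.random.default_rng(0)
T, D = 128, 16  # illustrative sizes, far smaller than a real batch
tokens = rng.standard_normal((T, D)).astype(np.float32)

for num_experts, k in [(8, 2), (56, 14)]:
    sel = random_routing(rng, T, num_experts, k)
    gathered, group_sizes, _ = dispatch(tokens, sel, num_experts)
    print(num_experts, k, gathered.shape, int(group_sizes.sum()))
```

Going from 2 to 14 experts per token makes the gathered array 7x larger, while the matmul work per gathered row shrinks 7x, so the memory-bound gather takes a larger share of the step.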

Here are the top 10 operations, in case you are curious: [Image: top 10 operations with 56 experts]

As a reference, these are the top 10 operations when using 8 experts (with 2 active per token and mlp_dim=14336), which gives higher MFU (30% instead of 20% on a v4-64):

[Image: top 10 operations with 8 experts]

rodrigo-f-nogueira avatar Feb 12 '25 18:02 rodrigo-f-nogueira

Thanks for the info! Yes, ideally we should see pallas_call among the top operations. Our team is working on DeepSeek-like model configs and has onboarded some functional features recently. We will also be working on performance optimizations over the next few weeks. I will reply here once we have some benchmarks, if that is OK.

RissyRan avatar Feb 19 '25 17:02 RissyRan

Awesome, thank you very much for your great work!

rodrigo-f-nogueira avatar Feb 19 '25 21:02 rodrigo-f-nogueira

Hello, I'm also hitting this problem. When can it be fixed?

Lisennlp avatar Jul 10 '25 07:07 Lisennlp

any updates on this?

Mddct avatar Jul 20 '25 02:07 Mddct

Thanks for reaching out! We did some internal benchmarks about DeepSeek v3 and Llama4 Maverick on Cloud v5p, using megablox, adamw, dtype=bf16, weight_dtype=f32, and FSDP sharding. The performance is around 35-40% MFU. For Llama4 Scout, we got much better performance (50%+ MFU).

Yes, so far jnp.take is the slowest operation, especially in the backward pass (due to the slow scatter-add operation). We are still working on optimizing performance in the following directions:

  • Mixed quantization like DeepSeek v3
  • Expert parallelism optimization
  • Pipeline parallelism with megablox
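To make the scatter-add point concrete: the backward pass of a row gather has to add the upstream gradient of every copy back into its source row, and with 14 experts per token each row receives 14 contributions. The sketch below uses NumPy (`np.add.at`) as a stand-in; in JAX the analogous operation is an indexed add such as `x.at[idx].add(y)`. The function name and sizes are made up for illustration.

```python
import numpy as np


def take_backward(upstream, token_idx, num_tokens):
    """Gradient of np.take(tokens, token_idx, axis=0) w.r.t. tokens.

    upstream: [N, D] gradient of the gathered rows.
    token_idx: [N] source row of each gathered row (duplicates allowed).
    """
    grad = np.zeros((num_tokens, upstream.shape[1]), dtype=upstream.dtype)
    # Unbuffered add: rows that were gathered several times accumulate
    # every one of their contributions (a plain grad[idx] += ... would not).
    np.add.at(grad, token_idx, upstream)
    return grad


# Three tokens; token 0 was gathered three times, token 2 twice.
token_idx = np.array([0, 2, 0, 1, 2, 0])
upstream = np.ones((6, 4), dtype=np.float32)
grad = take_backward(upstream, token_idx, num_tokens=3)

# Reference: a gather is a one-hot matmul, so its gradient is the transpose.
one_hot = np.eye(3, dtype=np.float32)[token_idx]   # [6, 3]
reference = one_hot.T @ upstream                    # [3, 4]
print(np.allclose(grad, reference))
```

The duplicate indices are what make this a scatter-add rather than a simple scatter, and the number of duplicates per row grows with the number of experts per token.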

Also, I would recommend reading this model customization guide on TPU, especially if you'd like to customize the configs a bit.

cc @rodrigo-f-nogueira @Lisennlp @Mddct

RissyRan avatar Jul 21 '25 20:07 RissyRan

> Thanks for reaching out! We did some internal benchmarks about DeepSeek v3 and Llama4 Maverick on Cloud v5p, using megablox, adamw, dtype=bf16, weight_dtype=f32, and FSDP sharding. The performance is around 35-40% MFU. For Llama4 Scout, we got much better performance (50%+ MFU).

@RissyRan Thank you so much for sharing the progress on optimizing dMoE training speed! Could you also share the config used when benchmarking the DeepSeek v3 models at 35-40% MFU? How many experts and how many top-k experts per token were used in those experiments? I want to reproduce them, but I can't find this in the model customization guide. In particular, I see training slow down significantly as the number of top-k experts increases.

hilbertmeng avatar Jul 22 '25 06:07 hilbertmeng