Integration of Fused MoE Kernel (e.g., MegaBlocks) for Efficient MoE Training
⚠️ Please check that this feature request hasn't been suggested before.
- [x] I searched previous Ideas in Discussions and didn't find any similar feature requests.
- [x] I searched previous Issues and didn't find any similar feature requests.
🔖 Feature description
The default Hugging Face MoE implementation is inefficient because it loops over the experts with a Python for loop, which causes low utilization when training models like Qwen3 30B-A3B. We might want a drop-in kernel patch for more efficient MoE computation. A previous issue on this: https://github.com/axolotl-ai-cloud/axolotl/issues/930
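For context, here is a rough sketch of the per-expert loop pattern this refers to (illustrative module and names, not the exact transformers code); the Python loop over experts is what a fused kernel such as MegaBlocks' grouped GEMM would replace:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMoE(nn.Module):
    """Illustrative top-k MoE block with the per-expert Python loop."""
    def __init__(self, hidden_size, intermediate_size, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, intermediate_size, bias=False),
                nn.SiLU(),
                nn.Linear(intermediate_size, hidden_size, bias=False),
            )
            for _ in range(num_experts)
        )

    def forward(self, x):
        tokens = x.view(-1, x.shape[-1])
        weights, selected = torch.topk(
            F.softmax(self.router(tokens), dim=-1), self.top_k, dim=-1
        )
        out = torch.zeros_like(tokens)
        # This loop launches many small, poorly utilized GEMMs, one per expert;
        # a fused/grouped kernel processes all experts in a single launch.
        for expert_idx, expert in enumerate(self.experts):
            token_idx, k_idx = torch.where(selected == expert_idx)
            if token_idx.numel() == 0:
                continue
            out[token_idx] += weights[token_idx, k_idx, None] * expert(tokens[token_idx])
        return out.view_as(x)
```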
✔️ Solution
Integrating solutions like https://github.com/databricks/megablocks/
❓ Alternatives
No response
📝 Additional Context
No response
Acknowledgements
- [x] My issue title is concise, descriptive, and in title casing.
- [x] I have searched the existing issues to make sure this feature has not been requested yet.
- [x] I have provided enough information for the maintainers to understand and evaluate this request.
Yep, this is something I was just checking. I saw that the upstream transformers EP PR was merged: https://github.com/huggingface/transformers/pull/39501. It uses kernels-community/megablocks (not sure if it's the same as Databricks' one).
I just need to read more on how it can be used and apply it to Qwen3 too.
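For anyone following along, a minimal sketch of pulling that hub kernel down with the `kernels` package; the exact symbols exposed by kernels-community/megablocks aren't documented here, so this just lists what's available:

```python
# Minimal sketch, assuming `pip install kernels`.
from kernels import get_kernel

megablocks = get_kernel("kernels-community/megablocks")
# Inspect what the kernel repo actually exposes before wiring it into a model.
print([name for name in dir(megablocks) if not name.startswith("_")])
```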
It seems the original kernel is bound to GPT-OSS. I have made it work for Qwen3, but DeepSpeed causes trouble if I try to merge the expert weights after loading; otherwise it requires merging the weights beforehand.
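Roughly, the "merge the expert weights" step means stacking each expert's projections into grouped tensors, which is the layout grouped-GEMM kernels expect. A hedged sketch (the gate_proj/up_proj/down_proj attribute names assume the Qwen3 MoE modeling code; under DeepSpeed ZeRO-3 these parameters may already be partitioned after loading, which is presumably where the trouble comes from):

```python
import torch

@torch.no_grad()
def merge_expert_weights(moe_block):
    # Stack per-expert nn.Linear weights into single grouped tensors.
    # Shapes assume Qwen3-style experts: each weight is (out_features, in_features).
    gate = torch.stack([e.gate_proj.weight for e in moe_block.experts])  # (E, inter, hidden)
    up = torch.stack([e.up_proj.weight for e in moe_block.experts])      # (E, inter, hidden)
    down = torch.stack([e.down_proj.weight for e in moe_block.experts])  # (E, hidden, inter)
    return gate, up, down
```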
@zinccat , could you share how you got it working for qwen3 for reference purposes?
This is currently a WIP for us.
sure, I'll provide my version later
you may reference the commits in https://github.com/zinccat/qwen3_moe_megablocks
@zinccat , correct me if I'm wrong but is the shape for the router mixed up?
self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size, dtype=torch.bfloat16))
Should it be:
self.weight = nn.Parameter(torch.empty(
config.hidden_size, config.num_experts, dtype=torch.bfloat16
))
? Ref: https://github.com/huggingface/transformers/blob/9b4bd96e2b3e0c65bac21706dfeb2fc5ff7e3c22/src/transformers/models/llama4/modeling_llama4.py#L133
It's either self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size, dtype=torch.bfloat16)) or making the class a subclass of nn.Linear(config.hidden_size, config.num_experts), as specified in your reference.
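Both formulations produce the same logits, since F.linear takes a weight of shape (out_features, in_features). A quick sanity check:

```python
import torch
import torch.nn.functional as F
from torch import nn

hidden_size, num_experts, n_tokens = 8, 4, 3
weight = torch.randn(num_experts, hidden_size)  # (out_features, in_features), as F.linear expects
tokens = torch.randn(n_tokens, hidden_size)

router = nn.Linear(hidden_size, num_experts, bias=False)
router.weight.data.copy_(weight)

logits_param = F.linear(tokens, weight)  # nn.Parameter formulation
logits_linear = router(tokens)           # nn.Linear formulation
assert torch.allclose(logits_param, logits_linear)  # both are (n_tokens, num_experts)
```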
Just an FYI, I've been poking around with the kernels-community/megablocks code myself; here's a HIP-compatible version I've been bringing up (all numerics are valid, however I get RCCL hangs on my MI300X...): https://huggingface.co/shisa-ai/megablocks-hip