
Integration of fused MoE kernel (e.g., MegaBlocks) for efficient MoE training


⚠️ Please check that this feature request hasn't been suggested before.

  • [x] I searched previous Ideas in Discussions and didn't find any similar feature requests.
  • [x] I searched previous Issues and didn't find any similar feature requests.

🔖 Feature description

The default Hugging Face MoE implementation is inefficient because it runs a Python for loop over the experts, which causes low GPU utilization when training models like Qwen3 30B-A3B. We might want a drop-in kernel patch for more efficient MoE computation. A previous issue on this: https://github.com/axolotl-ai-cloud/axolotl/issues/930
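
A minimal sketch of the pattern being described (illustrative only, not the actual transformers code): the stock MoE forward dispatches tokens with a Python loop over experts, so each expert launches its own small GEMMs and the GPU sits mostly idle in between. Fused kernels such as MegaBlocks replace this loop with grouped/block-sparse GEMMs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMoE(nn.Module):
    """Simplified top-k MoE layer mirroring the per-expert for-loop pattern."""
    def __init__(self, hidden_size=64, intermediate_size=128, num_experts=8, top_k=2):
        super().__init__()
        self.num_experts, self.top_k = num_experts, top_k
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        self.up = nn.ModuleList(nn.Linear(hidden_size, intermediate_size, bias=False) for _ in range(num_experts))
        self.down = nn.ModuleList(nn.Linear(intermediate_size, hidden_size, bias=False) for _ in range(num_experts))

    def forward(self, x):  # x: (tokens, hidden)
        probs = F.softmax(self.router(x), dim=-1)
        weights, indices = probs.topk(self.top_k, dim=-1)   # (tokens, top_k)
        out = torch.zeros_like(x)
        for e in range(self.num_experts):                    # <-- the serial per-expert loop
            token_idx, slot = (indices == e).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            h = self.down[e](F.silu(self.up[e](x[token_idx])))
            out.index_add_(0, token_idx, h * weights[token_idx, slot].unsqueeze(-1))
        return out

x = torch.randn(16, 64)
print(NaiveMoE()(x).shape)  # torch.Size([16, 64])
```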

✔️ Solution

Integrating solutions like https://github.com/databricks/megablocks/

❓ Alternatives

No response

📝 Additional Context

No response

Acknowledgements

  • [x] My issue title is concise, descriptive, and in title casing.
  • [x] I have searched the existing issues to make sure this feature has not been requested yet.
  • [x] I have provided enough information for the maintainers to understand and evaluate this request.

zinccat · Sep 12 '25 04:09

Yep, this is something I was just checking. I saw that the upstream transformers EP PR was merged: https://github.com/huggingface/transformers/pull/39501. It uses kernels-community/megablocks (not sure if it's the same as Databricks' one).

I just need to read more on how it can be used and apply it to Qwen3 too.
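
A minimal sketch of how such a hub kernel can be pulled in, assuming the transformers EP path resolves it through the kernels package; what the kernels-community/megablocks module actually exposes isn't shown in this thread, so the sketch only loads and inspects it rather than calling specific entry points.

```python
from kernels import get_kernel

# Resolve the hub repo to a compiled extension module.
megablocks = get_kernel("kernels-community/megablocks")

# Inspect what the hub kernel actually ships before wiring it into a model patch.
print([name for name in dir(megablocks) if not name.startswith("_")])
```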

NanoCode012 · Sep 12 '25 04:09

It seems the original kernel is bound to GPT-OSS. I have made it work for Qwen3, but DeepSpeed causes trouble if I try to merge the expert weights after loading; otherwise, the weights have to be merged beforehand.
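
A hedged sketch of the "merge the expert weights" step: stacking the per-expert Linear weights into single (num_experts, ...) tensors of the kind a grouped-GEMM / MegaBlocks-style kernel consumes. The attribute names (experts, gate_proj/up_proj/down_proj) follow the Qwen3-MoE layout in transformers, but the helper itself is hypothetical, not axolotl or megablocks code.

```python
import torch

@torch.no_grad()
def stack_expert_weights(moe_block):
    """Return fused (E, out, in) weight tensors from a list-of-experts MoE block."""
    gate = torch.stack([e.gate_proj.weight for e in moe_block.experts])  # (E, inter, hidden)
    up   = torch.stack([e.up_proj.weight   for e in moe_block.experts])  # (E, inter, hidden)
    down = torch.stack([e.down_proj.weight for e in moe_block.experts])  # (E, hidden, inter)
    return gate, up, down

# Under DeepSpeed ZeRO-3 the per-expert parameters are partitioned across ranks after
# the model is loaded, so stacking them post-hoc requires gathering them first
# (e.g. via deepspeed.zero.GatheredParameters); converting the checkpoint before
# loading avoids that entirely.
```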

zinccat · Sep 19 '25 06:09

@zinccat, could you share how you got it working for Qwen3, for reference purposes?

This is currently a WIP for us.

NanoCode012 · Sep 23 '25 05:09

Sure, I'll provide my version later.

zinccat · Sep 23 '25 16:09

You may reference the commits in https://github.com/zinccat/qwen3_moe_megablocks

zinccat · Sep 24 '25 03:09

@zinccat, correct me if I'm wrong, but is the shape for the router mixed up?

```python
self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size, dtype=torch.bfloat16))
```

Should it be:

```python
self.weight = nn.Parameter(torch.empty(
    config.hidden_size, config.num_experts, dtype=torch.bfloat16
))
```

? Ref: https://github.com/huggingface/transformers/blob/9b4bd96e2b3e0c65bac21706dfeb2fc5ff7e3c22/src/transformers/models/llama4/modeling_llama4.py#L133

NanoCode012 · Sep 26 '25 08:09

It's either self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size, dtype=torch.bfloat16)), or the class should be a subclass of nn.Linear(config.hidden_size, config.num_experts), as specified in your reference.
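
A small check of the point above (illustrative values, not model code): a raw (num_experts, hidden_size) parameter used with F.linear and an nn.Linear(hidden_size, num_experts) store the weight in the same (out_features, in_features) layout and produce identical router logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, num_experts, tokens = 64, 8, 5
x = torch.randn(tokens, hidden_size)

# Raw parameter in (out_features, in_features) = (num_experts, hidden_size) layout.
weight = nn.Parameter(torch.empty(num_experts, hidden_size))
nn.init.normal_(weight, std=0.02)
logits_a = F.linear(x, weight)                 # (tokens, num_experts)

# Equivalent nn.Linear formulation; nn.Linear.weight already has shape (out, in).
router = nn.Linear(hidden_size, num_experts, bias=False)
with torch.no_grad():
    router.weight.copy_(weight)
logits_b = router(x)

print(torch.allclose(logits_a, logits_b))      # True
```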

zinccat · Sep 26 '25 15:09

Just an FYI, I've been poking around with the kernels-community/megablocks code myself. Here's a HIP-compatible version I've been bringing up (all numerics are valid, however I get RCCL hangs on my MI300X...): https://huggingface.co/shisa-ai/megablocks-hip

lhl · Sep 30 '25 03:09