LlamaMOE: the order of softmax and topK
Dear authors,
I notice that in HF's Mixtral:
https://github.com/huggingface/transformers/blob/b109257f4fb8b1166e7c53cc5418632014ed53a5/src/transformers/models/mixtral/modeling_mixtral.py#L852
Softmax is called before topK, and the probabilities are re-normalized after topK. In LitGPT's LlamaMOE, on the other hand:
https://github.com/Lightning-AI/litgpt/blob/c81800f455dd997f786cbe2e110eff1f5c0d2d3b/litgpt/model.py#L341
Softmax is called after topK.
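Schematically, the two orderings look like this as I understand them (a simplified sketch, not the actual code from either repo; names and shapes are my own):

```python
import torch
import torch.nn.functional as F

def route_softmax_first(router_logits: torch.Tensor, top_k: int):
    """HF Mixtral order: softmax over all experts, then top-k, then re-normalize."""
    probs = F.softmax(router_logits, dim=-1)
    topk_probs, topk_idx = torch.topk(probs, top_k, dim=-1)
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
    return topk_probs, topk_idx

def route_topk_first(router_logits: torch.Tensor, top_k: int):
    """Mistral-reference / LitGPT order: top-k on the raw logits, then softmax over the selection."""
    topk_logits, topk_idx = torch.topk(router_logits, top_k, dim=-1)
    topk_probs = F.softmax(topk_logits, dim=-1)
    return topk_probs, topk_idx
```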
May I ask what your thoughts are behind this, and whether this difference would cause any mismatch in the models' outputs?
Thank you so much!
Our implementation most closely follows the Mistral reference: https://github.com/mistralai/mistral-src/blob/main/moe_one_file_ref.py#L205-L212. Please also note the docstring reference: https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py#L335
I cannot comment on why HF chose to follow a different approach, but there doesn't seem to be a numerical difference, at least in fp32.
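For what it's worth, the two orderings should agree beyond floating-point noise: softmax is order-preserving, so both pick the same experts, and re-normalizing the top-k softmax probabilities is algebraically the same as applying softmax to the top-k logits. A quick standalone check (my own sketch, not code from either repo):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(16, 8)  # hypothetical (num_tokens, num_experts) router logits
k = 2

# HF Mixtral order: softmax over all experts, top-k, then re-normalize.
probs = F.softmax(logits, dim=-1)
p1, idx1 = torch.topk(probs, k, dim=-1)
p1 = p1 / p1.sum(dim=-1, keepdim=True)

# Mistral-reference / LitGPT order: top-k on the raw logits, then softmax over the selected logits.
l2, idx2 = torch.topk(logits, k, dim=-1)
p2 = F.softmax(l2, dim=-1)

print(torch.equal(idx1, idx2))            # True: same experts are selected
print(torch.allclose(p1, p2, atol=1e-6))  # True: same routing weights up to fp error
```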