How to use Medusa to support non-LLaMA models?
System Info
Hardware: L20
Version: 0.11.0.dev20240625
Model: Bloom-7b1
Who can help?
@ncomly-nvidia @byshiue I have obtained the Medusa heads for Bloom according to the official Medusa documentation, but for deployment I need to modify bloom/model.py. I made a modified version using llama/model.py as a reference, but the accuracy is very poor. Therefore, I have two questions:
- Does Medusa support deploying models other than the LLaMA family?
- For the model.py of other model types, please provide modification tips like the official Medusa '[MODIFIED]' markers, reference: https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/modeling_llama_kv.py. I mainly adapted the spec_decoding_params parameter in bloom/model.py.
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
1. Train Medusa heads for the Bloom model
2. Adapt the spec_decoding_params parameter in bloom/model.py
Expected behavior
nothing
Actual behavior
nothing
Additional notes
nothing
Is GatedMLP required for Medusa decoding? I found two characteristics during debugging:
- The only difference between the modified bloom/model.py and llama lies in the MLP layer: llama uses GatedMLP, while Bloom uses a plain MLP.
- When the accepted tokens come from the Medusa result, the last accepted token is always misaligned. Are plain MLP layers unsuitable for the Medusa algorithm?
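For context, the structural difference between the two MLP variants looks roughly like this (my own minimal NumPy sketch, not the TensorRT-LLM implementation):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def gated_mlp(x, w_gate, w_up, w_down):
    # LLaMA-style GatedMLP: a SiLU-gated projection multiplied elementwise
    # with an ungated up-projection, then projected back down.
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

def plain_mlp(x, w_fc, w_proj):
    # Bloom-style MLP: a single up-projection, GELU activation, down-projection.
    return gelu(x @ w_fc) @ w_proj
```

Both variants operate on each token independently and do not mix information across positions, which is consistent with the reply below that the MLP type by itself should not affect Medusa.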
Hi @skyCreateXian, thank you for bringing this up. Agreed, we should have documentation on the steps required to make Medusa work with other models. I think you are on the right track. The following steps should be enough to support Medusa for other models:
- Adding `spec_decoding_params` to the base model (e.g. Bloom in this case); a rough sketch of this plumbing follows after this list.
- A new conversion script to combine the base model and the Medusa heads into a TensorRT-LLM checkpoint.
- Changing medusa/model.py to use the updated base model.
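To illustrate the first step, here is a highly simplified, standalone sketch of the plumbing change (class and field names are hypothetical placeholders, not the exact TensorRT-LLM API; the real change mirrors what llama/model.py already does):

```python
from dataclasses import dataclass
from typing import Any, Optional

# Hypothetical stand-in for the extra tensors Medusa needs every attention
# layer to receive (tree-attention mask, per-token position offsets,
# generation lengths). Field names are illustrative only.
@dataclass
class SpecDecodingParams:
    generation_lengths: Any = None
    position_offsets: Any = None
    packed_mask: Any = None

class AttentionStub:
    # Stands in for the real attention layer; only the plumbing is shown.
    def forward(self, hidden_states, spec_decoding_params=None):
        if spec_decoding_params is not None:
            # In the real model these tensors are handed to the attention
            # plugin so it can apply the Medusa tree mask and offsets.
            pass
        return hidden_states

class BloomLayerSketch:
    def __init__(self):
        self.attention = AttentionStub()

    def forward(self, hidden_states,
                spec_decoding_params: Optional[SpecDecodingParams] = None):
        # The only change versus the vanilla layer: accept the extra argument
        # and pass it down, mirroring llama/model.py.
        return self.attention.forward(
            hidden_states, spec_decoding_params=spec_decoding_params)

# The Medusa wrapper would build SpecDecodingParams from the draft tree and
# call the base model with it on every generation step.
layer = BloomLayerSketch()
layer.forward([0.1, 0.2], spec_decoding_params=SpecDecodingParams())
```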
To answer your question on MLP, it shouldn't have any effect on Medusa. One other difference I can think of which can lead to poor accuracy is the position embedding: RoPE vs ALiBi. With Medusa, a position offset tensor is passed to the model to properly apply the position embedding to the Medusa tokens. I am not too familiar with ALiBi yet, but if it requires more than just the position offsets, then that could be the other thing that is needed to support Medusa with Bloom.
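To make the position-offset point concrete, here is a toy illustration (my own sketch, not TensorRT-LLM code): with a Medusa draft tree, candidate tokens at the same tree depth share the same position, so the offsets passed to the model are not simply sequential.

```python
# Toy Medusa tree hanging off the last accepted token at position p:
# two candidate children at depth 1, two grandchildren at depth 2.
# Nodes at the same depth share a position id, so the per-token offsets
# are [1, 1, 2, 2] rather than [1, 2, 3, 4].
p = 10                               # position of the last accepted token
tree_depths = [1, 1, 2, 2]           # depth of each speculative token
position_ids = [p + d for d in tree_depths]
print(position_ids)                  # [11, 11, 12, 12]

# RoPE rotates each query/key by an angle that depends only on these
# per-token position ids, so passing the offset tensor is enough.
# ALiBi instead adds a bias to each attention score based on the
# query-key distance, so the attention kernel itself has to understand
# the tree layout, not just per-token offsets.
```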
I hope this helps. Please let us know how it goes and/or if you have any more questions.
@rakib-hasan How can I verify the differences caused by the position encoding algorithm? I found that forcibly modifying the `position_embedding_type` in convert_checkpoint to `rope_gpt_neox` did not work.
Hello, and how to use Medusa to support the Qwen model? It's different from LLaMA and Bloom.
@sundayKK I adapted Qwen2-7B, but found that the result was completely different from the base model, so it failed. You can follow the steps below (a rough sketch of steps 2 and 3 follows the list):
- Adapt Qwen training in Medusa to obtain the trained heads
- Modify models/medusa/model.py to support Qwen
- Modify models/qwen/model.py to support the speculative decoding parameters
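For steps 2 and 3, the shape of the change is roughly the following (hypothetical class names, only showing how the Medusa wrapper delegates to a different backbone; not the actual models/medusa/model.py code):

```python
# Hypothetical sketch: the Medusa wrapper owns the Medusa heads and delegates
# the backbone forward pass to whichever base model class it is given, so
# supporting Qwen is mostly constructing the Qwen backbone here (once
# qwen/model.py accepts the speculative decoding parameters).
class BackboneStub:
    def __init__(self, config):
        self.config = config

    def forward(self, input_ids, spec_decoding_params=None):
        return input_ids  # placeholder for real hidden states

class MedusaModelSketch:
    def __init__(self, config, base_model_cls):
        # base_model_cls could be the Qwen backbone instead of LLaMA's.
        self.backbone = base_model_cls(config)
        self.medusa_heads = [self._make_head()
                             for _ in range(config["num_medusa_heads"])]

    def _make_head(self):
        # Each Medusa head is a small residual MLP plus an lm_head-sized
        # projection loaded from the trained Medusa checkpoint (placeholder here).
        return lambda hidden: hidden

    def forward(self, input_ids, spec_decoding_params=None):
        hidden = self.backbone.forward(input_ids, spec_decoding_params)
        draft_logits = [head(hidden) for head in self.medusa_heads]
        return hidden, draft_logits

model = MedusaModelSketch({"num_medusa_heads": 4}, BackboneStub)
model.forward([1, 2, 3])
```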
@skyCreateXian Apologies for the late response. That sounds correct. Changing the position encoding at inference time won't work, as the Bloom model seems to be trained with ALiBi. The problem is that, as I understand it, the XQA kernel supports tree attention (required by Medusa) but doesn't support ALiBi. So, at this point, Medusa with models that use ALiBi won't work.
@sundayKK It seems Qwen2 uses RoPE, so it should be compatible. I do not know that architecture's details yet, but are there any other differences between Qwen2 and LLaMA?
@skyCreateXian @rakib-hasan thanks for your answer! I'd like to try.
@rakib-hasan I tested Qwen2-7B and found that it cannot be aligned on this model, so I suspect that the diff is not caused by positional encoding differences; I will continue to check.
@skyCreateXian any updates on this?
@poweiw no, thanks
Closing for now. Feel free to reopen when the work is resumed, thanks! 👍