
How to use Medusa to support non llama models?

skyCreateXian opened this issue 1 year ago · 8 comments

System Info

Hardware: L20
Version: 0.11.0.dev20240625
Model: Bloom-7b1

Who can help?

@ncomly-nvidia @byshiue I have obtained the Medusa heads for Bloom by following the official Medusa documentation, but to deploy them I need to modify bloom/model.py. I adapted a version by referencing llama/model.py, but the accuracy is very poor. Therefore, I have two questions:

  1. Does Medusa support deploying models other than the LLaMA family?
  2. For other models' model.py, please provide official Medusa modification tips, like the '[MODIFIED]' markers in this reference: https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/modeling_llama_kv.py. I mainly adapted the spec_decoding_params parameter in bloom/model.py (a sketch of the tree-attention idea behind those '[MODIFIED]' blocks follows below).
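
For context, the '[MODIFIED]' blocks in the referenced Medusa file mainly add tree ("Medusa") attention: each candidate token may only attend to the verified context and to its own ancestors in the candidate tree. Below is a minimal, library-agnostic sketch of such a mask; the function name and layout are illustrative, not TensorRT-LLM or Medusa API.

```python
import torch

def build_medusa_tree_mask(parent: list[int]) -> torch.Tensor:
    """Build a tree attention mask for Medusa candidate tokens.

    parent[i] is the index of token i's parent in the candidate tree,
    or -1 if the token hangs directly off the last verified token.
    Returns a (num_tokens, num_tokens) boolean mask where entry (i, j)
    is True iff candidate i may attend to candidate j.
    """
    n = len(parent)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        mask[i, i] = True          # every token attends to itself
        j = parent[i]
        while j != -1:             # ...and to all of its ancestors
            mask[i, j] = True
            j = parent[j]
    return mask

# Example: candidate 0 has children 1 and 2, and candidate 1 has child 3.
print(build_medusa_tree_mask([-1, 0, 0, 1]).int())
```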

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

1. Train the Medusa heads for the Bloom model
2. Adapt the spec_decoding_params parameter in bloom/model.py

Expected behavior

nothing

actual behavior

nothing

additional notes

nothing

skyCreateXian · Jul 15 '24 03:07

Is GatedMLP required for Medusa decoding? I found two things during debugging:

  1. The only difference between the modified bloom/model.py and llama/model.py lies in the MLP layer: llama uses GatedMLP, while Bloom uses a plain MLP (a side-by-side sketch of the two variants follows below).
  2. When the accepted tokens come from the Medusa result, the last accepted token is never aligned. Is the plain MLP layer unsuitable for the Medusa algorithm?
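
For reference, here is the architectural difference in question, sketched with standard PyTorch modules (illustrative only, not the TensorRT-LLM classes):

```python
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    """LLaMA-style MLP: SiLU(gate(x)) * up(x), followed by a down projection."""
    def __init__(self, hidden: int, inter: int):
        super().__init__()
        self.gate = nn.Linear(hidden, inter, bias=False)
        self.up = nn.Linear(hidden, inter, bias=False)
        self.down = nn.Linear(inter, hidden, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

class PlainMLP(nn.Module):
    """Bloom-style MLP: a single up projection with GELU, then a down projection."""
    def __init__(self, hidden: int, inter: int):
        super().__init__()
        self.fc = nn.Linear(hidden, inter)
        self.proj = nn.Linear(inter, hidden)

    def forward(self, x):
        return self.proj(F.gelu(self.fc(x)))
```

Both variants act on each token independently, with no cross-token interaction, which is consistent with the reply below that the MLP choice should not affect Medusa.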

skyCreateXian · Jul 16 '24 03:07

Hi @skyCreateXian, thank you for bringing this up. Agreed, we should have documentation on the steps required to make Medusa work with other models. I think you are on the right track. The following steps should be enough to support Medusa for other models:

  1. Adding spec_decoding_params to the base model (e.g. Bloom in this case).
  2. A new conversion script to combine the base model and the Medusa heads into a TensorRT-LLM checkpoint (a rough sketch follows below).
  3. Changing medusa/model.py to use the updated base model.
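
A rough sketch of step 2, assuming the weight layout used by the Medusa example conversion script; every key name here (the medusa_heads.* targets and the head state_dict keys) is an assumption to verify against the TensorRT-LLM version you use.

```python
import torch
from safetensors.torch import load_file, save_file

def add_medusa_heads(base_ckpt: str, medusa_ckpt: str, out_ckpt: str,
                     num_heads: int, num_layers: int = 1) -> None:
    """Merge trained Medusa head weights into a converted base-model checkpoint."""
    weights = load_file(base_ckpt)                        # converted Bloom rank weights
    medusa = torch.load(medusa_ckpt, map_location="cpu")  # state_dict of the trained heads

    for h in range(num_heads):
        for l in range(num_layers):
            # ResBlock linear layer of head h, layer l (key names are assumptions).
            prefix = f"medusa_heads.{h}.medusa_layers.{l}.linear"
            weights[f"{prefix}.weight"] = medusa[f"{h}.{l}.linear.weight"].clone()
            weights[f"{prefix}.bias"] = medusa[f"{h}.{l}.linear.bias"].clone()
        # Per-head LM head that predicts the token h+1 positions ahead.
        weights[f"medusa_heads.{h}.lm_head.weight"] = medusa[f"{h}.{num_layers}.weight"].clone()

    save_file(weights, out_ckpt)
```

Steps 1 and 3 are mostly plumbing: the base model and its decoder layers accept a spec_decoding_params argument and forward it down to the Attention layer, mirroring what llama/model.py does, and medusa/model.py is pointed at the updated base model instead of LLaMA.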

To answer your question on the MLP, it shouldn't have any effect on Medusa. One other difference I can think of that can lead to poor accuracy is the position embedding: RoPE vs. ALiBi. With Medusa, a position-offset tensor is passed to the model to properly apply the position embedding to the Medusa tokens. I am not too familiar with ALiBi yet, but if it requires more than just the position offsets, then that could be the other thing needed to support Medusa with Bloom.
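
To make the position-offset point concrete, here is a small, library-agnostic illustration (not TensorRT-LLM code): with RoPE, each Medusa candidate only needs a position id equal to the last verified position plus its depth in the candidate tree, whereas ALiBi adds a per-head bias proportional to the query-key distance, so the tree structure would have to enter the bias computation as well.

```python
# Illustrative: compute position ids for Medusa candidate tokens from their
# depth in the candidate tree, which is what a RoPE-based kernel needs.
# ALiBi instead adds a bias of slope * (key_pos - query_pos) per head, so the
# same per-pair distances would have to respect the tree, not just an offset.
def medusa_position_ids(past_len: int, depth: list[int]) -> list[int]:
    """past_len: number of already-verified tokens; depth[i]: tree depth of
    candidate i (root candidates have depth 0)."""
    return [past_len + d for d in depth]

# Example: 10 verified tokens, one root candidate with two depth-1 children.
print(medusa_position_ids(10, [0, 1, 1]))  # [10, 11, 11]
```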

I hope this helps. Please let us know how it goes and/or if you have any more questions.

rakib-hasan · Jul 16 '24 05:07

@rakib-hasan How can I verify the differences caused by the position-encoding algorithm? I found that forcibly changing position_embedding_type to "rope_gpt_neox" in convert_checkpoint did not work.
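
One quick way to see which scheme the converted checkpoint actually declares is to inspect its generated config (a small sketch; the path is hypothetical, and the reply below explains why merely switching this field cannot work for a model trained with ALiBi):

```python
import json

# Path is hypothetical; point it at the output directory of convert_checkpoint.
with open("tllm_checkpoint_bloom/config.json") as f:
    cfg = json.load(f)

# Expected values: "alibi" for Bloom, "rope_gpt_neox" for LLaMA-style models.
print(cfg.get("position_embedding_type"))
```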

skyCreateXian · Jul 16 '24 13:07

Hello, and how can Medusa be used to support the Qwen model? It is different from both LLaMA and Bloom.

sundayKK · Jul 20 '24 02:07

@sundayKK I adapted Qwen2-7B, but found that the result was completely different from the base model, so it failed (a simple alignment check is sketched after the steps below). You can follow these steps:

  1. Adapt Qwen training in Medusa to obtain the trained heads
  2. Modify models/medusa/model.py to support Qwen
  3. Modify models/qwen/model.py to support the speculative-decoding parameters
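
Since the failure mode here is mismatched outputs, a simple way to quantify it is to compare the greedy outputs of the Medusa engine and the base model token by token (a minimal sketch; how you obtain the two token lists depends on your runner):

```python
def first_mismatch(base_tokens: list[int], medusa_tokens: list[int]) -> int | None:
    """Return the index of the first position where the two greedy outputs
    diverge, or None if they agree over the shared length."""
    for i, (a, b) in enumerate(zip(base_tokens, medusa_tokens)):
        if a != b:
            return i
    return None

# Example: the outputs diverge at position 3.
print(first_mismatch([5, 9, 2, 7, 1], [5, 9, 2, 4, 1]))  # -> 3
```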

skyCreateXian · Jul 23 '24 11:07

@skyCreateXian Apologies for the late response. That sounds correct. Changing the position encoding at inference time won't work, as the Bloom model was trained with ALiBi. The problem is that, as I understand it, the XQA kernel supports tree attention (required by Medusa) but doesn't support ALiBi. So, at this point, Medusa won't work with models that use ALiBi.

@sundayKK It seems Qwen2 uses RoPE, so it should be compatible. I don't know that architecture's details yet, but are there any other differences between Qwen2 and LLaMA?

rakib-hasan · Jul 24 '24 01:07

@skyCreateXian @rakib-hasan thanks for your answer! I'd like to try.

sundayKK · Jul 24 '24 02:07

@rakib-hasan I tested Qwen2-7B and found that it cannot be aligned on this model either, so I suspect the diff is not caused by positional-encoding differences. I will continue to check.

skyCreateXian · Aug 06 '24 09:08

@skyCreateXian any updates on this?

poweiw · May 21 '25 22:05

@poweiw no, thanks

skyCreateXian · May 22 '25 02:05

Closing for now. Feel free to reopen when the work resumes, thanks! 👍

poweiw · May 29 '25 19:05