
Add support for Speculative Decoding

Open OliverFM opened this issue 1 year ago • 7 comments

Feature request

There is a new and interesting paper from Google Research that promises 2-3X speedups of LLM inference by running two models in tandem. The core idea is to use a faster, lower-quality model that approximates the target model to sample multiple tokens, and then to check those samples with the target model. E.g. sample from LLaMA 7B quickly, then use LLaMA 70B to verify the samples.
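
For concreteness, here is a minimal Python/PyTorch sketch of the draft-then-verify loop described in that paper (Leviathan et al., 2023). `draft_step`, `target_probs`, and `gamma` are hypothetical placeholders, not TGI APIs; the acceptance rule is the rejection-sampling test from the paper.

```python
# A minimal sketch of one speculative decoding step. `draft_step(prefix)` and
# `target_probs(prefix, draft_tokens)` are hypothetical callables that return
# next-token probability distributions; they are not part of TGI.

import torch


def speculative_step(prefix, draft_step, target_probs, gamma: int = 4):
    """Propose `gamma` tokens with the draft model, then verify them with a
    single target-model pass, accepting a prefix of them via rejection sampling."""
    draft_tokens, draft_dists = [], []
    ctx = list(prefix)
    for _ in range(gamma):
        q = draft_step(ctx)                      # distribution from the small model
        tok = torch.multinomial(q, 1).item()
        draft_tokens.append(tok)
        draft_dists.append(q)
        ctx.append(tok)

    # One call to the big model scores all gamma positions (plus one bonus position).
    p_dists = target_probs(list(prefix), draft_tokens)

    accepted = []
    for i, tok in enumerate(draft_tokens):
        p, q = p_dists[i], draft_dists[i]
        if torch.rand(()) < min(1.0, (p[tok] / q[tok]).item()):
            accepted.append(tok)                 # keep the draft token
        else:
            # On rejection, resample from the residual distribution max(0, p - q).
            residual = torch.clamp(p - q, min=0)
            residual = residual / residual.sum()
            accepted.append(torch.multinomial(residual, 1).item())
            return accepted                      # stop at the first rejection
    # All gamma draft tokens accepted: take one bonus token from the target model.
    accepted.append(torch.multinomial(p_dists[gamma], 1).item())
    return accepted
```

Note that every step produces at least one new token (the resampled or bonus token), which is what keeps the output distribution identical to sampling from the target model alone.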

Motivation

Adding this kind of support would make LLM sampling much faster.

I have considered an alternative implementation where I run two copies of TGI, plus a new web server implementing speculative decoding, in the same Kubernetes pod/VM/server. However, the added overhead of running HTTP between all these containers is likely to erase a significant portion of the gains in inference speed.

Core challenges:

  1. Adding this feature would require making TGI more generic, so that one can run multiple models at once. We would need to make sure that this does not degrade performance or reliability for the single model use case.
  2. Running multiple models in one container would make GPU selection trickier, but, again, this should not be insurmountable.
  3. We would need to add some new options to the public API, which will require careful thought.

Your contribution

Presuming the maintainers are happy to add this feature, I would start work on implementing it. This would probably take the form of several PRs, as the change would be significant.

OliverFM avatar Jul 28 '23 15:07 OliverFM

Yes that's something that we want to explore.

OlivierDehaene avatar Jul 28 '23 15:07 OlivierDehaene

> Yes that's something that we want to explore.

What would be the best way to get this conversation moving? I have started looking at implementing this myself, but am not yet sure how this would best fit together.

For performance reasons it would be ideal if we could keep this step in the Rust code. I also suspect that it would be best placed somewhere in the router, since we would probably want new endpoints generate_speculative and generate_speculative_stream that rely on the existing routing code but add the logic to aggregate the calls to the two models.

Does this sound like a good idea?

OliverFM avatar Jul 28 '23 16:07 OliverFM

Hey thanks for proposing to contribute.

Disclaimer: it's August, so a lot of the team members are on vacation and I myself am handling various (too many, perhaps) projects, so dev bandwidth for adding this is rather low right now.

Brainstorm DX:

As a user I think I'd like to simply say:

text-generation-launcher --model-id BIG_MODEL --regular-params --speculative-model SMALL_MODEL_ID --speculative-steps 4

I think it's relatively simple and conveys the intent correctly. We need to think about TP sharding: do we want to shard the small model just like the big one, or can we expect it to always fit on a single GPU? If it fits on a single GPU, that GPU will have less free RAM than the others, but inference on the small model is likely to be faster than if it were split; is that acceptable when handling massive throughput?

If we need more complex control we'd have to break from the usual pure cli flags, and use some kind of config instead.

Implementation-wise:

As usual, the smaller the PR, the better it is. We should focus on getting 1 functioning case, not focus on all the options we could add.

  • We really need to make the sampling from distributions efficient. It can become a bottleneck extremely fast.
  • We could have 2 gRPC backends, one for each model; the router could control both, much as it does today.
    • We probably need to modify the Batch object in order to optionally send those speculative-steps logits instead of tokens
    • We need to send those logits to the larger models
    • We need to add this new LogitsProcessor.
  • We need to add a Prometheus metric tracking how many tokens were successfully skipped (that's the main interest of the method: we need to know how effective it is in real runtimes); see the sketch after this list.
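
To make that last bullet concrete, here is a minimal sketch of the quantities such a metric would track. The metric names and helper are made up, and in TGI the real metrics are exported from the Rust router, so this Python version is only an illustration of what to record.

```python
# Hypothetical acceptance-rate counters; not TGI's actual metrics.
from prometheus_client import Counter

# Comparing proposed vs. accepted draft tokens gives the real-world acceptance
# rate, which is the number that tells us whether speculation pays off.
SPECULATIVE_TOKENS_PROPOSED = Counter(
    "tgi_speculative_tokens_proposed_total",
    "Draft-model tokens proposed for verification",
)
SPECULATIVE_TOKENS_ACCEPTED = Counter(
    "tgi_speculative_tokens_accepted_total",
    "Draft-model tokens accepted by the target model",
)


def record_speculative_step(num_proposed: int, num_accepted: int) -> None:
    """Call once per speculative decoding step, after verification."""
    SPECULATIVE_TOKENS_PROPOSED.inc(num_proposed)
    SPECULATIVE_TOKENS_ACCEPTED.inc(num_accepted)
```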

I think keeping the exact same routes, /generate and /generate_stream, is more than OK. IMHO, speculative decoding is purely an optimization trick, not something users of the model/API care about.

Narsil avatar Aug 02 '23 17:08 Narsil

This is great.

But I think there might be more performance gains to be had elsewhere before getting to speculative decoding.

In many cases I have seen, for LLMs larger than 3B the CPU is the bottleneck for text generation. https://github.com/huggingface/transformers/issues/24524

Introducing more layers without resolving the underlying bottleneck might not be fruitful.

calvintwr avatar Aug 21 '23 23:08 calvintwr

@calvintwr we're not using transformers here. CPU bottlenecking is a real issue, but we have a longer-term solution for it: https://github.com/huggingface/candle/

Narsil avatar Aug 22 '23 09:08 Narsil

@calvintwr, yes this CPU bottleneck is the reason we often re-write the modelling code in TGI.

Speculative decoding is our main priority for the next release.

OlivierDehaene avatar Sep 06 '23 13:09 OlivierDehaene

I don't have the expertise to comment, but just in case: if this https://twitter.com/tianle_cai/status/1701272996328120408 is good enough to be a viable alternative to the "separate draft model" approach to speculative decoding, then perhaps it could inform some design decisions for the eventual form this API takes (even just the naming of parameters like --speculative-model).


Edit: Actually, after more than a 2-second skim of that Twitter thread, I think this is "baked into" the model/network itself? So maybe use of this technique can be opaque to TGI? (I probably shouldn't be commenting here - just adding newbie noise.)

Medusa adds extra "heads" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned. During generation, these heads each produce multiple likely words for the corresponding position. These options are then combined and processed using a tree-based attention mechanism. Finally, a typical acceptance scheme is employed to pick the longest plausible prefix from the candidates for further decoding.

https://github.com/FasterDecoding/Medusa
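
For intuition, here is a small sketch of the "longest plausible prefix" selection the quoted description refers to. The `is_plausible` callable stands in for Medusa's typical-acceptance test, which this sketch does not reproduce; everything here is illustrative rather than Medusa's actual code.

```python
# Hypothetical sketch: pick the longest accepted prefix among Medusa-style
# candidate continuations. `is_plausible(prefix_so_far, token)` is a stand-in
# for the real acceptance criterion.
from typing import Callable, List, Sequence


def longest_accepted_prefix(
    candidates: Sequence[List[int]],
    is_plausible: Callable[[List[int], int], bool],
) -> List[int]:
    """Return the longest prefix, across all candidate continuations, whose
    tokens are each accepted in order by `is_plausible`."""
    best: List[int] = []
    for cand in candidates:
        prefix: List[int] = []
        for tok in cand:
            if not is_plausible(prefix, tok):
                break
            prefix.append(tok)
        if len(prefix) > len(best):
            best = prefix
    return best
```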

josephrocca avatar Sep 11 '23 17:09 josephrocca

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot] avatar Apr 18 '24 01:04 github-actions[bot]