
[Speculative Decoding] Enabling bonus tokens in speculative decoding for KV-cache-based models



FIX #4212

In this PR we make the following changes:

  1. Update the spec_decode_worker to keep track of the sequence ids that were assigned bonus token ids in their last forward pass (a tracking sketch follows this list). We record this only for the MultiStepWorker, since other worker types don't use the KV cache for token generation. Known limitation: we currently never clear sequence ids from this set, even on sequence termination; we still need a way to be notified of termination so the stale ids can be removed.
  2. Updated the MultiStepWorker to expand the batch during step 0 (sketched below this list). During batch expansion we check which sequence ids were assigned bonus tokens in their last forward pass. For each of those sequences we add a new sequence (with the same seq_id) without the bonus token. Once the forward pass for step 0 completes, we filter the responses so that only those corresponding to the original sequences are retained.
  3. Added a flag --disable-bonus-tokens-in-kv-cache to toggle bonus tokens for the MultiStepWorker (usage example below).
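
For (1), here is a minimal sketch of the bookkeeping, under illustrative names: `SpecDecodeWorkerSketch`, `record_step`, and `_seqs_with_bonus_token` are hypothetical stand-ins, not vLLM's actual classes or fields.

```python
from typing import Iterable, Set


class SpecDecodeWorkerSketch:
    """Illustrative bookkeeping only; not the real spec_decode_worker."""

    def __init__(self) -> None:
        # Sequence ids that received a bonus token in their last forward
        # pass. As noted in (1), entries are currently never evicted on
        # sequence termination.
        self._seqs_with_bonus_token: Set[int] = set()

    def record_step(self,
                    seq_ids: Iterable[int],
                    num_accepted: Iterable[int],
                    num_proposed: Iterable[int]) -> None:
        """Update the set after one verification step.

        A bonus token is kept only when every proposed token of a
        sequence was accepted, so the target model's extra sample gets
        appended to that sequence's output.
        """
        for seq_id, accepted, proposed in zip(seq_ids, num_accepted,
                                              num_proposed):
            if accepted == proposed:
                self._seqs_with_bonus_token.add(seq_id)
            else:
                self._seqs_with_bonus_token.discard(seq_id)
```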
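
For (2), a sketch of the step-0 batch expansion and filtering. The helpers (`expand_batch_for_step_0`, `filter_step_0_outputs`) and the plain `(seq_id, token_ids)` tuples are simplifications; the real MultiStepWorker operates on vLLM's sequence metadata structures.

```python
from typing import List, Sequence, Set, Tuple

Seq = Tuple[int, List[int]]  # (seq_id, token_ids) -- simplified stand-in


def expand_batch_for_step_0(
        seqs: Sequence[Seq],
        seqs_with_bonus_token: Set[int]) -> Tuple[List[Seq], List[int]]:
    """Duplicate each sequence that got a bonus token last step.

    The duplicate shares the seq_id but drops the trailing bonus token,
    matching the draft model's KV cache, which never processed it.
    """
    expanded: List[Seq] = []
    original_indices: List[int] = []  # positions of the original sequences
    for seq_id, token_ids in seqs:
        original_indices.append(len(expanded))
        expanded.append((seq_id, token_ids))
        if seq_id in seqs_with_bonus_token:
            # Same seq_id, but without the bonus (last) token.
            expanded.append((seq_id, token_ids[:-1]))
    return expanded, original_indices


def filter_step_0_outputs(outputs: Sequence,
                          original_indices: Sequence[int]) -> List:
    """After the step-0 forward pass, retain only the outputs that
    correspond to the original (non-duplicated) sequences."""
    return [outputs[i] for i in original_indices]
```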
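
For (3), a hypothetical offline-API usage of the flag. Whether the CLI flag maps to an engine keyword named exactly `disable_bonus_tokens_in_kv_cache` is an assumption here, based on vLLM's usual flag-to-kwarg convention.

```python
# Hypothetical usage; the kwarg name mirrors the CLI flag and is an
# assumption, as is the exact flag-to-EngineArgs mapping.
from vllm import LLM

llm = LLM(
    model="JackFram/llama-160m",             # target model
    speculative_model="JackFram/llama-68m",  # draft model
    num_speculative_tokens=1,
    # Opt out of the new behavior and drop bonus tokens as before:
    disable_bonus_tokens_in_kv_cache=True,
)
```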

Some numbers from e2e tests (note that the e2e tests don't use CUDA graphs). The draft model is JackFram/llama-68m, the target model is JackFram/llama-160m, and the batch size is 64. Completion time for num_speculation = 1 shows a ~33% speedup (a quick breakdown follows the raw output):

  • w/o bonus Processed prompts: 100%|█████████████████64/64 [00:06<00:00, 10.13it/s, est. speed input: 78.48 toks/s, output: 2592.40 toks/s]

  • with bonus Processed prompts: 100%|█████████████████████ 64/64 [00:04<00:00, 15.50it/s, est. speed input: 120.13 toks/s, output: 3968.28 toks/s]
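
For reference on how these figures relate: the quoted ~33% appears to refer to wall-clock completion time (00:06 → 00:04, i.e. 1 − 4/6 ≈ 33% less time), while output throughput improves by a factor of 3968.28 / 2592.40 ≈ 1.53x.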
