feat: Add initial EAGLE-3 implementation

Open mikeiovine opened this issue 8 months ago • 20 comments

Add an initial implementation of EAGLE3. The algorithm is not fully implemented yet. Specifically, this MR predicts only a single linear sequence of draft tokens, not the full token trees introduced by EAGLE2 and extended by EAGLE3.

Performance

The performance is nevertheless improved. The following data is from the MT bench dataset.

Without spec decode
E2E Request Time {'min': '0.1272', 'max': '10.0149', 'mean': '3.9125', 'std': '2.5989', 'quantiles': {'0.25': '1.5769', '0.5': '3.7717', '0.75': '6.0113'}}
TTFT Time {'min': '0.0112', 'max': '1.2804', 'mean': '0.0255', 'std': '0.0996', 'quantiles': {'0.25': '0.0141', '0.5': '0.0166', '0.75': '0.0207'}}
Generation Step Time {'min': '0.0076', 'max': '1.2804', 'mean': '0.0093', 'std': '0.0050', 'quantiles': {'0.25': '0.0090', '0.5': '0.0093', '0.75': '0.0096'}}

With EAGLE3 (4 draft tokens)
E2E Request Time {'min': '0.1232', 'max': '7.0007', 'mean': '2.5146', 'std': '1.6071', 'quantiles': {'0.25': '1.1102', '0.5': '2.5067', '0.75': '3.7379'}}
TTFT Time {'min': '0.0114', 'max': '1.2818', 'mean': '0.0256', 'std': '0.0997', 'quantiles': {'0.25': '0.0137', '0.5': '0.0171', '0.75': '0.0208'}}
Generation Step Time {'min': '0.0111', 'max': '1.2818', 'mean': '0.0140', 'std': '0.0075', 'quantiles': {'0.25': '0.0136', '0.5': '0.0139', '0.75': '0.0142'}}

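Note the reduction in mean E2E request time (3.9125 down to 2.5146, about 36% lower), even though the mean generation step time rises from 0.0093 to 0.0140. I am seeing between 2 and 3 tokens generated per iteration on average, with some variance depending on the prompts.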
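As a rough sanity check on these numbers, the sketch below relates the E2E speedup to the per-step slowdown. It assumes that decode time dominates and that an iteration emitting k tokens on average gives a decode speedup of roughly k divided by the step slowdown. Subtracting mean TTFT from mean E2E time is an approximation, not something the benchmark reports directly.

```python
# Mean values copied from the benchmark output above.
baseline = {"e2e": 3.9125, "ttft": 0.0255, "step": 0.0093}
eagle3 = {"e2e": 2.5146, "ttft": 0.0256, "step": 0.0140}

# End-to-end speedup from the mean request times.
e2e_speedup = baseline["e2e"] / eagle3["e2e"]

# Each EAGLE3 iteration costs more: it runs the draft model and
# verifies several tokens in the target model.
step_slowdown = eagle3["step"] / baseline["step"]

# Approximate decode-only speedup by removing mean TTFT from both sides.
decode_speedup = (baseline["e2e"] - baseline["ttft"]) / (
    eagle3["e2e"] - eagle3["ttft"]
)

# If decode_speedup ~= k / step_slowdown, then k ~= decode_speedup * step_slowdown.
implied_tokens_per_iter = decode_speedup * step_slowdown

print(f"E2E speedup:             {e2e_speedup:.2f}x")
print(f"Step slowdown:           {step_slowdown:.2f}x")
print(f"Implied tokens per iter: {implied_tokens_per_iter:.2f}")
```

The implied value lands between 2 and 3, which is consistent with the tokens-per-iteration range observed above.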

This was measured with Llama 3.1 8B Instruct on a single H200. Other important parts of the setup: no KV cache reuse, no CUDA graphs.

CUDA graphs are disabled because they don't work with EAGLE3 yet. We need to re-enable them in a follow-up. The trend I see across most prompts is that (EAGLE3 + no overlap scheduler) is faster than (vanilla + overlap scheduler), but slower than (vanilla + overlap scheduler + CUDA graphs). CUDA graphs matter a lot, so I disabled them in both configurations to get an apples-to-apples comparison for these benchmarks.

Design Overview

I tried to keep it as simple as possible and reuse our existing abstractions.

  • No overlap scheduler for now. This is possible to add in theory, but it will make the code pretty messy once token trees get involved. I think the best way to proceed is to finish implementing token trees + CUDA graphs, then add overlap scheduler support if the ROI is high enough. The other 2 items are much easier to do and will boost our performance by a lot more than the overlap scheduler.

  • The EAGLE3 draft model is loaded in a separate ModelEngine. Using this abstraction is crucial for keeping the code simple, because preparing AttentionMetadata by hand is very error-prone. An added bonus is that we can easily enable CUDA graphs when invoking the draft model autoregressively, since all of the logic already exists inside ModelEngine.

  • Before invoking the target model forward, I call a separate prepare_draft_tokens function. This will append draft tokens to generation requests.

  • I think this prepare_draft_tokens function can be easily reused for e.g. Medusa, draft/target spec decode, or future versions of EAGLE. My secondary goal is to avoid extensive runtime changes for adding future spec decode algorithms.
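To illustrate the drafting step described in the bullets above, here is a minimal sketch of attaching a single linear chain of draft tokens to each generation request. `GenerationRequest`, `toy_draft`, and the signature of `prepare_draft_tokens` are hypothetical stand-ins, not the classes or signatures in this PR, and the real draft model runs batched inside a ModelEngine rather than per request.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class GenerationRequest:
    """Hypothetical stand-in for a generation request (not the real class)."""
    request_id: int
    tokens: List[int]
    draft_tokens: List[int] = field(default_factory=list)


# Assumption: a draft model is any callable mapping a token prefix to one token.
DraftFn = Callable[[List[int]], int]


def prepare_draft_tokens(
    requests: List[GenerationRequest],
    draft_next_token: DraftFn,
    num_draft_tokens: int,
) -> None:
    """Attach one linear chain of draft tokens to each request (no token trees)."""
    for req in requests:
        prefix = list(req.tokens)
        req.draft_tokens = []
        for _ in range(num_draft_tokens):
            # Run the draft model autoregressively on its own predictions.
            next_token = draft_next_token(prefix)
            req.draft_tokens.append(next_token)
            prefix.append(next_token)


def toy_draft(prefix: List[int]) -> int:
    """Toy draft model: predicts the previous token plus one."""
    return (prefix[-1] + 1) % 100


requests = [GenerationRequest(0, [1, 2, 3]), GenerationRequest(1, [10])]
prepare_draft_tokens(requests, toy_draft, num_draft_tokens=4)
for req in requests:
    print(req.request_id, req.draft_tokens)
```

Because the function only needs something that proposes tokens for a prefix, the same shape could in principle serve other proposers, which is the reuse argument made above.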

mikeiovine avatar Mar 24 '25 17:03 mikeiovine

Adding in a comment from @lfr-0531 from the MR to the old repo:

What do you think about adding a new spec_executor.py to speculative/? Then we can add a new SpecExecutor that inherits from PyExecutor, and add _prepare_draft_batch and _prepare_draft_tokens to SpecExecutor. That way, all of the code related to speculative decoding will be in speculative/.

I think I'd rather keep all the executor stuff in the pyexecutor/ directory to avoid circular dependencies. As to whether or not we want a separate SpecExecutor subclass: I'm also not sure about this point. We would have a lot of awkward no-op functions in the base class. Perhaps it makes more sense to introduce this when we add the overlap scheduler support (since at that point we will likely need to completely override _executor_loop_overlap).

mikeiovine avatar Mar 24 '25 21:03 mikeiovine

/bot run

mikeiovine avatar Mar 25 '25 15:03 mikeiovine

PR_Github #446 [ run ] triggered by Bot

niukuo avatar Mar 25 '25 15:03 niukuo

PR_Github #446 [ run ] completed with state FAILURE /LLM/main/L0_MergeRequest_PR pipeline #381 completed with status: 'FAILURE'

niukuo avatar Mar 25 '25 17:03 niukuo

I have decided to allow the usage of the EAGLE3 checkpoints provided by the original paper authors on HuggingFace: https://huggingface.co/yuhuili/EAGLE3-DeepSeek-R1-Distill-LLaMA-8B

We will need a few hacks on our side to automatically detect an EAGLE3 checkpoint. These hacks can mostly be avoided if we provide our own converted checkpoint format, but I don't want to commit to any particular conventions right now. Since we don't have any custom checkpoints readily available on the NV HuggingFace yet, this will improve the user experience.

Further down the line, we need to make it compatible with fine-tuned EAGLE checkpoints from ModelOpt. That requires a bit of alignment, which we can handle after the rest of the spec decode work is finalized.

(For testing: I am still waiting for the EAGLE3 checkpoints to get checked into our llm weights repo for CI)

mikeiovine avatar Mar 25 '25 17:03 mikeiovine

/bot run

mikeiovine avatar Mar 25 '25 17:03 mikeiovine

PR_Github #453 [ run ] triggered by Bot

niukuo avatar Mar 25 '25 17:03 niukuo

PR_Github #453 [ run ] completed with state FAILURE /LLM/main/L0_MergeRequest_PR pipeline #388 completed with status: 'FAILURE'

niukuo avatar Mar 25 '25 19:03 niukuo

I have decided to separate the KVCacheManagers for the target and draft models. This has the following advantages:

  1. Avoids _LAYER_INDEX_OFFSET hack when creating the models
  2. Lets us support draft models with different structures/KV cache dtypes

The disadvantage is that we duplicate the request -> page ID bookkeeping overhead. The KV cache managers are kept in sync by the ResourceManager in PyExecutor. It is important that the draft model uses the same request IDs as the target model, so that the draft model can locate the right page IDs.

I am sweeping estimate_kv_cache_max_num_tokens under the rug for now. The draft model usually doesn't use much memory. I'll make the estimate more accurate in a follow-up; I left a TODO.
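A toy sketch of the bookkeeping described above: two independent page tables, keyed by the same request ID and driven together by one resource manager. `ToyKVCacheManager` and `ToyResourceManager` are hypothetical simplifications, not the real KVCacheManager or ResourceManager APIs, and the page sizes are made up.

```python
from typing import Dict, List


class ToyKVCacheManager:
    """Hypothetical, simplified page allocator (not the real KVCacheManager)."""

    def __init__(self, num_pages: int, tokens_per_page: int) -> None:
        self.tokens_per_page = tokens_per_page
        self.free_pages: List[int] = list(range(num_pages))
        self.pages_by_request: Dict[int, List[int]] = {}

    def ensure_capacity(self, request_id: int, num_tokens: int) -> None:
        pages = self.pages_by_request.setdefault(request_id, [])
        needed = -(-num_tokens // self.tokens_per_page)  # ceiling division
        while len(pages) < needed:
            # Toy behavior: raises IndexError if the pool is exhausted.
            pages.append(self.free_pages.pop())

    def free(self, request_id: int) -> None:
        self.free_pages.extend(self.pages_by_request.pop(request_id, []))


class ToyResourceManager:
    """Keeps every cache manager in sync by using one shared request ID."""

    def __init__(self, managers: Dict[str, ToyKVCacheManager]) -> None:
        self.managers = managers

    def prepare(self, request_id: int, num_tokens: int) -> None:
        for manager in self.managers.values():
            manager.ensure_capacity(request_id, num_tokens)

    def release(self, request_id: int) -> None:
        for manager in self.managers.values():
            manager.free(request_id)


# Separate managers let the draft model use a different layout than the target.
target_cache = ToyKVCacheManager(num_pages=64, tokens_per_page=16)
draft_cache = ToyKVCacheManager(num_pages=16, tokens_per_page=32)
resources = ToyResourceManager({"target": target_cache, "draft": draft_cache})

resources.prepare(request_id=7, num_tokens=40)
target_pages = len(target_cache.pages_by_request[7])
draft_pages = len(draft_cache.pages_by_request[7])
print(target_pages, draft_pages)

resources.release(request_id=7)
```

The duplicated `pages_by_request` dictionaries are the extra bookkeeping cost mentioned above; the shared request ID is what lets either model find its own pages.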

mikeiovine avatar Mar 26 '25 20:03 mikeiovine

Now that the EAGLE3 checkpoints have landed in our llm-models repository, I can finally add tests to this PR. There are now tests for the following:

  1. A sanity check on quickstart_advanced.py
  2. Check that the token acceptance rate is reasonable (>25% for the prompt I picked)
  3. Check that outputs with spec decode exactly match outputs without
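For context on checks 2 and 3, here is a self-contained toy showing why greedy verification should reproduce the non-speculative output exactly, and how an acceptance rate can be computed. The two "models" are made-up arithmetic functions, and the function names are illustrative, not the tests in this PR. A real target model verifies all draft positions in one forward pass rather than one call per token.

```python
from typing import List, Tuple


def target_next(prefix: List[int]) -> int:
    """Toy target model: deterministic next token for a prefix."""
    return (3 * prefix[-1] + len(prefix)) % 17


def draft_next(prefix: List[int]) -> int:
    """Toy draft model: disagrees with the target on every third position."""
    token = target_next(prefix)
    return token if len(prefix) % 3 else (token + 1) % 17


def greedy_decode(prompt: List[int], max_new_tokens: int) -> List[int]:
    tokens = list(prompt)
    while len(tokens) - len(prompt) < max_new_tokens:
        tokens.append(target_next(tokens))
    return tokens[len(prompt):]


def speculative_decode(
    prompt: List[int], max_new_tokens: int, num_draft_tokens: int
) -> Tuple[List[int], float]:
    tokens = list(prompt)
    drafted = accepted = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        # Draft a single linear chain of tokens.
        drafts: List[int] = []
        for _ in range(num_draft_tokens):
            drafts.append(draft_next(tokens + drafts))
        drafted += len(drafts)
        # Verify: always keep the target's token; stop at the first mismatch.
        for draft in drafts:
            expected = target_next(tokens)
            tokens.append(expected)
            if draft != expected:
                break
            accepted += 1
        else:
            # Every draft was accepted, so the target yields one bonus token.
            tokens.append(target_next(tokens))
    new_tokens = tokens[len(prompt):][:max_new_tokens]
    return new_tokens, accepted / drafted


prompt = [1, 2]
reference = greedy_decode(prompt, max_new_tokens=20)
spec_tokens, acceptance_rate = speculative_decode(
    prompt, max_new_tokens=20, num_draft_tokens=4
)
print(spec_tokens == reference, f"{acceptance_rate:.2f}")
```

Since every emitted token comes from the target, a mismatch between the two outputs would point to a bug in drafting or verification rather than to draft model quality.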

mikeiovine avatar Mar 26 '25 21:03 mikeiovine

/bot run

mikeiovine avatar Mar 26 '25 21:03 mikeiovine

PR_Github #613 [ run ] triggered by Bot

tensorrt-cicd avatar Mar 26 '25 21:03 tensorrt-cicd

PR_Github #613 [ run ] completed with state FAILURE /LLM/main/L0_MergeRequest_PR pipeline #517 completed with status: 'FAILURE'

tensorrt-cicd avatar Mar 26 '25 21:03 tensorrt-cicd

/bot run --disable-fail-fast

mikeiovine avatar Mar 28 '25 17:03 mikeiovine

PR_Github #681 [ run ] triggered by Bot

tensorrt-cicd avatar Mar 28 '25 17:03 tensorrt-cicd

PR_Github #681 [ run ] completed with state SUCCESS /LLM/main/L0_MergeRequest_PR pipeline #572 completed with status: 'SUCCESS'

tensorrt-cicd avatar Mar 28 '25 21:03 tensorrt-cicd

/bot skip --comment "Pipeline passed before rebase"

mikeiovine avatar Mar 29 '25 03:03 mikeiovine

PR_Github #685 [ skip ] triggered by Bot

tensorrt-cicd avatar Mar 29 '25 03:03 tensorrt-cicd

PR_Github #685 [ skip ] completed with state SUCCESS Skipping testing for commit 35f1ab0

tensorrt-cicd avatar Mar 29 '25 03:03 tensorrt-cicd

@lfr-0531 @QiJune: can either of you take a look? @hlu1 does not have repository write access yet, so the merge is blocked. Thanks

mikeiovine avatar Mar 29 '25 03:03 mikeiovine

/bot reuse-pipeline

lfr-0531 avatar Mar 29 '25 14:03 lfr-0531

PR_Github #690 [ reuse-pipeline ] triggered by Bot

tensorrt-cicd avatar Mar 29 '25 14:03 tensorrt-cicd

PR_Github #690 [ reuse-pipeline ] completed with state SUCCESS Reusing PR_Github #681 for commit df5efa0

tensorrt-cicd avatar Mar 29 '25 14:03 tensorrt-cicd