Speculative decoding does not work
Hi Team, I'm testing the speculative decoding feature with TensorRT-LLM (trtllm), but I'm running into an issue. Here is my setup:
- hardware: A100 80GB
- software: nvcr.io/nvidia/tritonserver:25.01-trtllm-python-py3
- model: gemma-2-2b-it (draft) / gemma-2-27b-it (target)
```bash
cd /llm/tmp/trtllm/v0.17/TensorRT-LLM/examples/gemma/
```
1. Convert and build the draft model
```bash
python3 convert_checkpoint.py \
    --ckpt-type hf \
    --model-dir /llm/tmp/model/gemma-2-2b-it/ \
    --dtype bfloat16 \
    --world-size 1 \
    --output-model-dir gemma2_2b_ckpt

trtllm-build \
    --checkpoint_dir ./gemma2_2b_ckpt \
    --output_dir ./gemma2_2b_engine \
    --max_batch_size 8 \
    --max_input_len 6000 \
    --max_seq_len 6400 \
    --cluster_key A100-SXM-80GB \
    --remove_input_padding enable \
    --kv_cache_type paged \
    --gpt_attention_plugin bfloat16 \
    --gemm_plugin bfloat16 \
    --context_fmha enable \
    --use_paged_context_fmha enable \
    --use_fused_mlp enable \
    --gather_generation_logits

# test draft model
python3 ../run.py \
    --engine_dir ./gemma2_2b_engine \
    --tokenizer_dir /llm/tmp/model/gemma-2-2b-it/ \
    --max_output_len=200 \
    --input_text "The basic idea of a Transformer model is"
```
Testing the draft model on its own, the inference output is fine.
2. Convert and build the target model with speculative decoding
```bash
python3 convert_checkpoint.py \
    --ckpt-type hf \
    --model-dir /llm/tmp/model/gemma-2-27b-it \
    --dtype bfloat16 \
    --world-size 1 \
    --output-model-dir gemma2_27b_ckpt

trtllm-build \
    --checkpoint_dir ./gemma2_27b_ckpt \
    --output_dir ./gemma2_27b_engine \
    --max_batch_size 8 \
    --max_input_len 6000 \
    --max_seq_len 6400 \
    --cluster_key A100-SXM-80GB \
    --remove_input_padding enable \
    --kv_cache_type paged \
    --gpt_attention_plugin bfloat16 \
    --gemm_plugin bfloat16 \
    --context_fmha enable \
    --use_paged_context_fmha enable \
    --use_fused_mlp enable \
    --gather_generation_logits \
    --speculative_decoding_mode draft_tokens_external \
    --max_draft_len=10

# test target model with the draft engine
python3 ../run.py \
    --engine_dir ./gemma2_27b_engine \
    --draft_engine_dir ./gemma2_2b_engine \
    --tokenizer_dir /llm/tmp/model/gemma-2-2b-it/ \
    --kv_cache_enable_block_reuse \
    --kv_cache_free_gpu_memory_fraction=0.95 \
    --max_output_len=200 \
    --input_text "The basic idea of a Transformer model is"
```
However, the generated text from this run is strange.
If there is any issue with my script, please point it out. Many thanks~
Please follow this example for running draft-target speculative decoding with run.py: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/draft_target_model
In your example, you are missing the `--draft_target_model_config` argument.
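For illustration, a corrected invocation might look like the sketch below, adapted from the command above. Per the linked example, the `--draft_target_model_config` value has the form `[max_draft_len, [draft_device_ids], [target_device_ids], use_logits]`; the single-GPU device lists and `False` for logits-based acceptance here are assumptions to adapt to your setup, and the exact format should be verified against your TensorRT-LLM version:

```bash
# Sketch only: config format taken from the draft_target_model example.
# [10,[0],[0],False] = up to 10 draft tokens (matching --max_draft_len=10
# at build time), draft engine on GPU 0, target engine on GPU 0,
# token-based (not logits-based) acceptance.
python3 ../run.py \
    --engine_dir ./gemma2_27b_engine \
    --draft_engine_dir ./gemma2_2b_engine \
    --tokenizer_dir /llm/tmp/model/gemma-2-2b-it/ \
    --draft_target_model_config="[10,[0],[0],False]" \
    --kv_cache_enable_block_reuse \
    --kv_cache_free_gpu_memory_fraction=0.95 \
    --max_output_len=200 \
    --input_text "The basic idea of a Transformer model is"
```

Note that with both engines sharing a single A100, a `--kv_cache_free_gpu_memory_fraction` of 0.95 may over-commit memory; you may need to lower it so both engines' KV caches fit together.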
Issue has not received an update in over 14 days. Adding stale label.
This issue was closed because it has been 14 days without activity since it was marked as stale.