Build failure when using FP8 quantized Medusa heads
System Info
CPU architecture: x86_64
GPU: NVIDIA H100
Libraries:
- TensorRT-LLM: v0.11.0
- TensorRT: 10.1.0
- ModelOpt: 0.13.1
- CUDA: 12.3
- NVIDIA driver version: 535.129.03
Issue
Hello, I'm hitting a failure when building an engine from a Mixtral + Medusa heads checkpoint quantized to FP8 (both weights and KV cache).
Reproduction
Steps to reproduce the behavior:
- Generate a Mixtral + Medusa heads TensorRT-LLM checkpoint in FP8 with the quantize.py script:

```
quantize.py --model_dir=<MODEL DIR> --dtype=float16 --tp_size=1 --output_dir=<CHECKPOINT DIR> --qformat=fp8 --kv_cache_dtype=fp8 --calib_dataset=<CALIB DATASET> --calib_size=512 --batch_size=8 --calib_max_seq_length=1024 --num_medusa_heads=2 --num_medusa_layers=1 --max_draft_len=2 --medusa_model_dir=<MEDUSA MODEL DIR> --quant_medusa_head
```

- Build the engine from the checkpoint:

```
trtllm-build --checkpoint_dir=<CHECKPOINT DIR> --max_beam_width=1 --max_seq_len=32768 --max_input_len=32368 --max_num_tokens=32768 --max_batch_size=4 --context_fmha=enable --use_custom_all_reduce=disable --output_dir=<OUT DIR> --use_fp8_context_fmha=disable --speculative_decoding_mode=medusa
```
Expected behavior
The expected output would be a correctly built Mixtral + Medusa heads FP8 engine.
Actual behavior
The checkpoint generation works and the quantized model is exported, but trtllm-build crashes with the following message:
```
Traceback (most recent call last):
  File "/usr/local/bin/trtllm-build", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 551, in main
    parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 373, in parallel_build
    passed = build_and_save(rank, rank % workers, ckpt_dir,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 340, in build_and_save
    engine = build_model(build_config,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 309, in build_model
    model = model_cls.from_checkpoint(ckpt_dir, config=rank_config)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/modeling_utils.py", line 426, in from_checkpoint
    model.load(weights, from_pruned=is_checkpoint_pruned)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/modeling_utils.py", line 439, in load
    raise RuntimeError(
RuntimeError: Required but not provided tensors:{'medusa_heads.1.lm_head.activation_scaling_factor', 'medusa_heads.1.lm_head.weights_scaling_factor', 'medusa_heads.0.lm_head.activation_scaling_factor', 'medusa_heads.1.medusa_layers.0.linear.weights_scaling_factor', 'medusa_heads.1.medusa_layers.0.linear.activation_scaling_factor', 'medusa_heads.0.medusa_layers.0.linear.weights_scaling_factor', 'medusa_heads.0.medusa_layers.0.linear.activation_scaling_factor', 'medusa_heads.0.lm_head.weights_scaling_factor'}
```
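For reference, whether these tensors were actually written by quantize.py can be checked by listing the keys in the exported checkpoint. Below is a minimal sketch, assuming the checkpoint is stored as a single-rank `rank0.safetensors` file in the checkpoint directory (the usual layout for tp_size=1 TensorRT-LLM checkpoints):

```python
# Minimal sketch: list Medusa-related tensors in the exported checkpoint.
# Assumes a single-rank checkpoint stored as rank0.safetensors; adjust the
# path for other TP/PP configurations.
from safetensors import safe_open

with safe_open("<CHECKPOINT DIR>/rank0.safetensors", framework="pt") as f:
    medusa_keys = sorted(k for k in f.keys() if k.startswith("medusa_heads"))
    for k in medusa_keys:
        print(k)
    # If --quant_medusa_head took effect, the *_scaling_factor entries named
    # in the error above should appear in this listing.
```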
The exact same trtllm-build command works correctly when starting from a checkpoint produced by the same quantize.py command with the single difference of omitting the --quant_medusa_head flag; in that case the engine builds successfully.
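If it helps triage, one way to confirm that the missing scaling factors are the only blocker is to inject placeholder values into the checkpoint before building. This is purely a diagnostic sketch, not a fix: it assumes a single-rank `rank0.safetensors` checkpoint, and the placeholder scale of 1.0 is not calibrated, so the resulting engine's Medusa heads would produce wrong outputs.

```python
# Diagnostic only, not a fix: inject placeholder FP8 scaling factors so the
# tensor-presence check in trtllm-build can pass. Assumes a single-rank
# rank0.safetensors checkpoint (hypothetical path placeholder below).
import torch
from safetensors.torch import load_file, save_file

path = "<CHECKPOINT DIR>/rank0.safetensors"
weights = load_file(path)

# The tensor names reported as missing in the RuntimeError above.
missing = [
    "medusa_heads.0.lm_head.activation_scaling_factor",
    "medusa_heads.0.lm_head.weights_scaling_factor",
    "medusa_heads.0.medusa_layers.0.linear.activation_scaling_factor",
    "medusa_heads.0.medusa_layers.0.linear.weights_scaling_factor",
    "medusa_heads.1.lm_head.activation_scaling_factor",
    "medusa_heads.1.lm_head.weights_scaling_factor",
    "medusa_heads.1.medusa_layers.0.linear.activation_scaling_factor",
    "medusa_heads.1.medusa_layers.0.linear.weights_scaling_factor",
]
for name in missing:
    # 1.0 is an uncalibrated placeholder, not a valid FP8 scale.
    weights.setdefault(name, torch.ones(1, dtype=torch.float32))

save_file(weights, path)
```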