Add Multi-Token Prediction (MTP) for DeepSeek V3 speculative decoding
Motivation
DeepSeek V3 includes a Multi-Token Prediction (MTP) layer (layer 61) that is currently discarded during model loading. This layer can be used for speculative decoding to improve generation throughput.
Based on vLLM/SGLang research:
- ~81-82% acceptance rate with a single draft token (k=1)
- 1.5-2x speedup at low QPS
- Diminishing returns at high concurrency
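As a rough sanity check on these numbers: with one draft token and an acceptance rate of ~0.81, each decode step commits on average 1 + 0.81 ≈ 1.8 tokens, so the reported 1.5-2x speedup is plausible once the cost of the extra MTP forward pass and verification is subtracted.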
Changes
New Files
- `src/exo/worker/engines/mlx/mtp/__init__.py` - Module exports
- `src/exo/worker/engines/mlx/mtp/module.py` - `MTPModule` class implementing the MTP architecture (sketched after this list)
- `src/exo/worker/engines/mlx/mtp/speculative_decode.py` - Speculative decoding logic
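For orientation, a minimal sketch of the shape of `MTPModule`, assuming the structure of the DeepSeek V3 MTP layer from the public checkpoint (two RMSNorms, a fusion projection, one decoder block, and the shared output head). The constructor arguments and call signature here are illustrative, not exo's actual API:

```python
import mlx.core as mx
import mlx.nn as nn

class MTPModule(nn.Module):
    """Sketch of the DeepSeek V3 MTP head (layer 61): fuses the main
    model's final hidden state with the embedding of the token it just
    produced, then predicts the following token."""

    def __init__(self, dims: int, decoder_block: nn.Module, lm_head: nn.Module):
        super().__init__()
        self.enorm = nn.RMSNorm(dims)       # normalizes the token embedding
        self.hnorm = nn.RMSNorm(dims)       # normalizes the main-model hidden state
        self.eh_proj = nn.Linear(2 * dims, dims, bias=False)  # fuses both streams
        self.block = decoder_block          # the single transformer block from layer 61
        self.lm_head = lm_head              # output head, shared with the main model

    def __call__(self, hidden: mx.array, token_emb: mx.array) -> mx.array:
        # Concatenation order follows the DeepSeek V3 paper; mask/cache
        # handling for the decoder block is elided in this sketch.
        fused = self.eh_proj(
            mx.concatenate([self.hnorm(hidden), self.enorm(token_emb)], axis=-1)
        )
        return self.lm_head(self.block(fused))
```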
Modified Files
- `src/exo/worker/engines/mlx/constants.py` - Added `MTP_ENABLED` and `MTP_NUM_DRAFT_TOKENS` config (sketched below)
- `src/exo/worker/engines/mlx/utils_mlx.py` - Patch `sanitize()` to preserve layer 61 and extract the MTP module
- `src/exo/worker/engines/mlx/generator/generate.py` - Integrate the MTP generation path
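Concretely, the new configuration and the weight-routing side of the `sanitize()` patch might look roughly like this; the constant names come from this PR, while the default values and the `split_mtp_weights` helper are illustrative assumptions:

```python
# constants.py (sketch) -- constant names from this PR; defaults assumed
MTP_ENABLED: bool = True         # master switch for the MTP generation path
MTP_NUM_DRAFT_TOKENS: int = 1    # k; the research cited above favors k=1

# utils_mlx.py (sketch) -- hypothetical helper showing the idea behind the
# sanitize() patch: instead of dropping layer 61, route its weights into a
# separate dict so the MTP module can be built from them.
def split_mtp_weights(weights: dict) -> tuple[dict, dict]:
    main, mtp = {}, {}
    for key, value in weights.items():
        (mtp if "model.layers.61." in key else main)[key] = value
    return main, mtp
```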
Why It Works
The MTP module reuses the hidden states from the main model's forward pass to draft the token that follows the one just generated, which enables speculative decoding without a separate draft model:
- Main model generates a token plus its final hidden state
- MTP module drafts the next token from that hidden state and the current token's embedding
- Main model verifies the draft on its next forward pass
- If accepted, both tokens are yielded; otherwise only the verified token is (a k=1 step is sketched below)
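Put together, one k=1 decode step looks roughly like the sketch below. The signatures are assumptions: `model` is the hidden-state-capturing wrapper described next (returning logits and hidden states), `model.embed_tokens` is its input embedding, and `mtp_module` is the MTP head. Sampling is greedy for clarity, and cache bookkeeping across steps is elided:

```python
import mlx.core as mx

def mtp_decode_step(model, mtp_module, token: mx.array, cache) -> list[int]:
    # 1. Main model: next-token logits plus the final-layer hidden state.
    logits, hidden = model(token[None, None], cache=cache)
    next_tok = mx.argmax(logits[0, -1])

    # 2. MTP draft: predict the token *after* next_tok from the hidden
    #    state and the embedding of the token just produced.
    emb = model.embed_tokens(next_tok[None, None])
    draft = mx.argmax(mtp_module(hidden[:, -1:, :], emb)[0, -1])

    # 3. Verify: advance the main model one step on next_tok; its own
    #    prediction either confirms or replaces the draft.
    verify_logits, _ = model(next_tok[None, None], cache=cache)
    verified = mx.argmax(verify_logits[0, -1])

    if verified.item() == draft.item():
        return [next_tok.item(), draft.item()]  # accepted: two tokens this step
    return [next_tok.item()]                    # rejected: only the verified token
```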
The key insight is that MTP needs hidden states (not just logits), which required wrapping the model to capture intermediate outputs.
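A sketch of that wrapping, assuming mlx_lm's usual layout where the transformer trunk lives on `.model` and returns hidden states, with the vocabulary projection on `.lm_head` (both attribute names are assumptions here):

```python
import mlx.core as mx
import mlx.nn as nn

class HiddenStateWrapper(nn.Module):
    """Runs the inner model's trunk and head separately so the final
    hidden states can be handed to the MTP module alongside the logits."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def embed_tokens(self, tokens: mx.array) -> mx.array:
        # Assumed accessor for the shared input embedding (MTP needs it too).
        return self.inner.model.embed_tokens(tokens)

    def __call__(self, tokens: mx.array, cache=None):
        hidden = self.inner.model(tokens, cache=cache)  # final-layer hidden states
        logits = self.inner.lm_head(hidden)             # vocabulary logits
        return logits, hidden
```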
Test Plan
Manual Testing
Automated Testing
- All 151 existing tests pass
- Ruff linting passes
- Type checking has expected `reportAny` errors from the external mlx_lm library (no new structural errors)
🤖 Generated with Claude Code