
Add Multi-Token Prediction (MTP) for DeepSeek V3 speculative decoding

AlexCheema opened this issue 1 month ago · 0 comments

Motivation

DeepSeek V3 includes a Multi-Token Prediction (MTP) layer (layer 61) that is currently discarded during model loading. This layer can be used for speculative decoding to improve generation throughput.

Based on vLLM/SGLang research:

  • 81-82% acceptance rate with k=1 draft tokens
  • 1.5-2x speedup at low QPS
  • Diminishing returns at high concurrency
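The reported numbers are self-consistent: with k=1 draft token and per-token acceptance probability p, each verification step emits 1 + p tokens on average. A back-of-envelope check (the formula is standard speculative-decoding arithmetic, not from this PR):

```python
def expected_tokens_per_step(acceptance_rate: float, k: int = 1) -> float:
    """Expected tokens emitted per main-model step with k draft tokens,
    assuming an independent acceptance probability for each draft token."""
    p = acceptance_rate
    # 1 verified token plus the geometric series of accepted drafts.
    return sum(p**i for i in range(k + 1))

for p in (0.81, 0.82):
    # At k=1 this gives 1.81-1.82 tokens per step, matching the
    # reported 1.5-2x speedup range once verification overhead is counted.
    print(f"p={p}: {expected_tokens_per_step(p):.2f} tokens/step")
```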

Changes

New Files

  • src/exo/worker/engines/mlx/mtp/__init__.py - Module exports
  • src/exo/worker/engines/mlx/mtp/module.py - MTPModule class implementing the MTP architecture
  • src/exo/worker/engines/mlx/mtp/speculative_decode.py - Speculative decoding logic

Modified Files

  • src/exo/worker/engines/mlx/constants.py - Added MTP_ENABLED and MTP_NUM_DRAFT_TOKENS config
  • src/exo/worker/engines/mlx/utils_mlx.py - Patch sanitize() to preserve layer 61, extract MTP module
  • src/exo/worker/engines/mlx/generator/generate.py - Integrate MTP generation path
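The sanitize() patch amounts to splitting the checkpoint instead of dropping layer 61. A minimal sketch of that idea (the real code lives in `utils_mlx.py`; the function name, weight-key prefix, and return shape here are assumptions, not exo's actual API):

```python
MTP_LAYER_INDEX = 61  # assumption: MTP weights live under "model.layers.61."

def sanitize_preserving_mtp(weights: dict) -> tuple[dict, dict]:
    """Split checkpoint weights into main-model weights and MTP weights,
    rather than discarding the MTP layer during loading."""
    mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
    mtp_weights = {k: v for k, v in weights.items() if k.startswith(mtp_prefix)}
    main_weights = {k: v for k, v in weights.items() if k not in mtp_weights}
    return main_weights, mtp_weights
```

The MTP weights can then be loaded into a standalone module while the main weights follow the existing path.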

Why It Works

The MTP module uses the hidden states from the main model's forward pass to predict the next token. This allows speculative decoding without a separate draft model:

  1. Main model generates token + hidden state
  2. MTP module predicts next token using hidden state + current token embedding
  3. Main model verifies the prediction
  4. If accepted, both tokens are yielded; otherwise only the verified token is yielded
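The four steps above can be sketched as a single accept/reject step for k=1. This is illustrative pseudocode against hypothetical callables, not the interface in `speculative_decode.py`:

```python
def speculative_step(main_forward, mtp_predict, tokens):
    """One generation step.
    main_forward(tokens) -> (next_token, hidden_state)  # hypothetical signature
    mtp_predict(hidden, token) -> draft_token           # hypothetical signature
    """
    # 1. Main model emits the next token and its final hidden state.
    token, hidden = main_forward(tokens)
    # 2. MTP module drafts the following token from hidden state + token.
    draft = mtp_predict(hidden, token)
    # 3. Main model verifies the draft by running one step further.
    check, _ = main_forward(tokens + [token])
    # 4. On a match both tokens are yielded; otherwise only the verified one.
    return [token, draft] if draft == check else [token]
```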

The key insight is that MTP needs hidden states (not just logits), which required wrapping the model to capture intermediate outputs.
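One way to capture those intermediate outputs without forking the model code is to intercept the final normalization layer, whose input is the pre-logits hidden state. A hedged sketch (the real wrapper is in `utils_mlx.py`; the attribute name `norm` and the class shape are assumptions):

```python
class HiddenStateCapture:
    """Wraps a model so the hidden state feeding the LM head is recorded
    as a side effect of an ordinary forward pass."""

    def __init__(self, model):
        self.model = model
        self.hidden = None
        inner_norm = model.norm  # assumption: final norm before the LM head

        def capturing_norm(x):
            self.hidden = x      # stash the hidden state for the MTP module
            return inner_norm(x)

        model.norm = capturing_norm

    def __call__(self, tokens):
        return self.model(tokens)
```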

Test Plan

Manual Testing

Automated Testing

  • All 151 existing tests pass
  • Ruff linting passes
  • Type checking reports only the expected reportAny errors from the external mlx_lm library (no new structural errors)

🤖 Generated with Claude Code

AlexCheema · Jan 18 '26 12:01