Add SmolLM3: Full and Quantized Implementation
Summary
This PR adds comprehensive support for SmolLM3-3B with both full precision (safetensors) and quantized (GGUF) implementations, unified under a single example interface.
What's New
Model Implementation
- Full precision model (`models/smol/smollm3.rs`): native safetensors support with F32/F16/BF16
- Quantized model (`models/smol/quantized_smollm3.rs`): GGUF support with Q4_K_M, Q8_0, and F16 quantization
- Unified example (`examples/smollm3/main.rs`): a single CLI that supports both model types seamlessly
SmolLM3 Architecture Features
- Hybrid RoPE/NoPE: 3:1 ratio, with every 4th layer using No Positional Encoding (see the sketch after this list)
- Grouped Query Attention: 32 attention heads with 8 KV heads (4 query heads per KV head)
- High RoPE theta: 5,000,000 (vs. the typical 10k-500k)
- Long context support: Up to 128k tokens
- Thinking mode: Support for explicit reasoning with `<think>` tags
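For illustration, here is a minimal sketch of the hybrid RoPE/NoPE layer selection. The function name and the `nope_layer_interval` parameter are illustrative only, not the PR's actual API; the real implementation reads the layer pattern from the model config.

```rust
/// Sketch only (illustrative names): with a 3:1 RoPE/NoPE ratio, every 4th
/// layer skips rotary position embeddings. The exact set of NoPE layers is
/// determined by the model config in the actual implementation.
fn layer_uses_rope(layer_idx: usize, nope_layer_interval: usize) -> bool {
    // e.g. interval = 4 => layers 3, 7, 11, ... (0-indexed) are NoPE layers
    (layer_idx + 1) % nope_layer_interval != 0
}

fn main() {
    for layer_idx in 0..8 {
        println!("layer {layer_idx}: applies RoPE = {}", layer_uses_rope(layer_idx, 4));
    }
}
```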
Verification
Output correctness verified against reference implementations:
- Full precision: Validated against HuggingFace Transformers Python implementation
- Quantized: Validated against llama.cpp (HuggingFace Transformers doesn't yet support quantized SmolLM3)
Performance
Tested on CPU and GPU with identical prompts (9 tokens generated):
| Model Type | Device | Speed (tokens/s) | Speedup |
|---|---|---|---|
| Q8_0 | CPU | 7.31 | 1.0x |
| Q8_0 | GPU | 45.84 | 6.3x |
| Full F16 | CPU | 2.54 | 1.0x |
| Full F16 | GPU | 32.22 | 12.7x |
Technical Details
Quantized Weight Reconstruction
The quantized implementation includes special handling for Q/K weight deinterleaving to maintain compatibility with GGUF format's interleaved storage pattern. The reconstruct_qk_weights() function properly reorganizes the attention weights.
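As a rough illustration of the idea (not the PR's actual `reconstruct_qk_weights()` code, and the exact axis order is an assumption that depends on how the GGUF file was exported), a de-interleave of this kind can be expressed as a reshape/transpose/reshape round trip over each projection matrix:

```rust
use candle_core::{Result, Tensor};

/// Illustrative sketch only: regroup per-head rows of a Q/K projection whose
/// rotary halves were stored interleaved (llama.cpp/GGUF convention) back
/// into contiguous halves. The axis order shown here is an assumption and
/// may differ from the PR's reconstruct_qk_weights().
fn deinterleave_qk(w: &Tensor, n_head: usize) -> Result<Tensor> {
    let (rows, cols) = w.dims2()?;
    let head_dim = rows / n_head;
    w.reshape((n_head, head_dim / 2, 2, cols))?
        .transpose(1, 2)? // -> (n_head, 2, head_dim / 2, cols)
        .contiguous()?
        .reshape((rows, cols))
}
```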
Future Work: Add optimized kernels that utilize CPU threads as effectively as llama.cpp's implementation.
KV-Cache Optimization Opportunity
The current implementation uses `.contiguous()` calls when appending to the KV cache:

```rust
// Can remove this contiguous call if using ConcatKV-Cache
// See: https://github.com/huggingface/candle/pull/3143
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
```
The ConcatKV-Cache implementation (#3143) offers significant performance improvements:
- GPU: Multiple orders of magnitude faster
- CPU/WASM: Equivalent performance with cleaner code
Action Item: I will open a separate issue to discuss adopting ConcatKV-Cache as the default KV-cache implementation across all transformer models in Candle. This would reduce duplicated cache logic across models and improve performance by default.
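For context, a concat-style cache boils down to a `Tensor::cat` along the sequence axis, which is why the explicit `.contiguous()` copies above become unnecessary. The sketch below is only an illustration of the idea, not the API introduced in #3143:

```rust
use candle_core::{D, Result, Tensor};

/// Illustrative sketch of a concat-based KV cache (not the #3143 API):
/// appending is a single concatenation along the sequence dimension.
#[derive(Default)]
struct ConcatKvCache {
    k: Option<Tensor>,
    v: Option<Tensor>,
}

impl ConcatKvCache {
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        // Assumes k/v are shaped (batch, n_kv_heads, seq, head_dim),
        // so D::Minus2 is the sequence axis.
        let k = match &self.k {
            Some(prev) => Tensor::cat(&[prev, k], D::Minus2)?,
            None => k.clone(),
        };
        let v = match &self.v {
            Some(prev) => Tensor::cat(&[prev, v], D::Minus2)?,
            None => v.clone(),
        };
        self.k = Some(k.clone());
        self.v = Some(v.clone());
        Ok((k, v))
    }
}
```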
Code Organization
This PR introduces an improved organizational pattern that should be considered for future transformer implementations:
Unified Module Structure
```text
models/smol/
├── mod.rs                  # Module documentation and exports
├── smollm3.rs              # Full precision implementation
├── quantized_smollm3.rs    # Quantized implementation
└── README.md               # Family documentation
```
Single Example for Multiple Model Types
The `examples/smollm3/main.rs` example demonstrates a unified approach:
- A single enum `SmolLM3Model` wrapping both implementations (sketched below)
- A unified `ModelConfig` abstraction for consistent access
- Shared generation logic regardless of model type
- A simple `--model-type` flag that switches between full and quantized
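A hedged sketch of this enum-dispatch pattern (the concrete type names and the `forward()` signature are assumptions; see `examples/smollm3/main.rs` for the actual code):

```rust
use candle_core::{Result, Tensor};
use candle_transformers::models::smol::{quantized_smollm3, smollm3};

// Sketch only: the type names and forward() signature in the PR may differ.
enum SmolLM3Model {
    Full(smollm3::Model),
    Quantized(quantized_smollm3::Model),
}

impl SmolLM3Model {
    fn forward(&mut self, input_ids: &Tensor, index_pos: usize) -> Result<Tensor> {
        // Both variants expose the same forward interface, so the generation
        // loop can be written once and shared by full and quantized models.
        match self {
            Self::Full(m) => m.forward(input_ids, index_pos),
            Self::Quantized(m) => m.forward(input_ids, index_pos),
        }
    }
}
```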
Benefits:
- User Experience: One example to learn, consistent CLI across model types
- Maintainability: Shared logic reduces duplication
- Testing: Single test harness validates both implementations
- Documentation: Easier to explain trade-offs between model types
This pattern could be adopted for other model families (e.g., Llama, Mistral) to provide a more cohesive user experience.
Example Usage
```bash
# Quantized model (fast, smaller memory)
cargo run --release --example smollm3 -- \
  --model-type quantized \
  --quantization q8_0 \
  --prompt "Explain Rust's ownership system"

# Full precision model (highest quality)
cargo run --release --example smollm3 -- \
  --model-type full \
  --dtype f16 \
  --prompt "Explain Rust's ownership system"

# Enable thinking mode for reasoning tasks
cargo run --release --example smollm3 -- \
  --thinking \
  --prompt "Solve this logic puzzle step by step"
```
Testing
- Builds successfully on CPU and GPU configurations
- Quantized model (Q8_0) outputs match llama.cpp reference
- Full model outputs match HuggingFace Transformers
- KV-cache correctly maintains state across generation
- NoPE layers properly skip positional encoding per config
- Thinking mode formats prompts correctly
Files Changed
New Files:
- `candle-transformers/src/models/smol/mod.rs`
- `candle-transformers/src/models/smol/smollm3.rs`
- `candle-transformers/src/models/smol/quantized_smollm3.rs`
- `candle-transformers/src/models/smol/README.md`
- `candle-examples/examples/smollm3/main.rs`
- `candle-examples/examples/smollm3/README.md`
Modified Files:
- `candle-transformers/src/models/mod.rs` (added `pub mod smol;`)
- `candle-examples/Cargo.toml` (added `chrono = "0.4"`)
Related Issues
- Issue to be created: Adopt ConcatKV-Cache (#3143) as default for all transformers
- Model Card Details: https://huggingface.co/HuggingFaceTB/SmolLM3-3B
- SmolLM Blog Series: https://huggingface.co/blog/smollm and https://huggingface.co/blog/smollm3
- Related to NoPE paper: https://arxiv.org/abs/2410.01926
Checklist
- [x] Code follows Candle style guidelines
- [x] Verified outputs against reference implementations
- [x] Documentation added (README, rustdoc comments)
- [x] Example demonstrates both quantized and full precision usage
- [x] Tested on CPU and GPU
- [x] No compiler warnings
- [x] Proper error handling throughout