Add Arcee model support
Summary
This PR adds support for the Arcee model architecture, laying the groundwork for the upcoming Arcee Foundation Model (AFM) release. Arcee is a decoder-only transformer model based on the Llama architecture with a key modification: it uses ReLU² (ReLU-squared) activation in the MLP blocks instead of SiLU, following recent research showing improved training efficiency with squared activations.
Model Description
Arcee is architecturally similar to Llama but with the following distinctions:
- ReLU² activation: Uses `x * relu(x)` in MLP layers for improved gradient flow (see the sketch after this list)
- Optimized for efficiency: Designed with training and inference efficiency in mind
- Extended context: Supports longer context windows via RoPE scaling
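To make the activation concrete, here is a minimal sketch of an MLP block using ReLU². It assumes a plain up/down projection without Llama's gating; the class and parameter names are illustrative and may differ from the actual ArceeMLP in this PR.

```python
import torch
import torch.nn as nn


class ReLU2MLP(nn.Module):
    """Illustrative MLP block using the ReLU² activation x * relu(x)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.up_proj(x)
        # ReLU² activation: x * relu(x), equivalent to relu(x) ** 2,
        # replacing the SiLU used in Llama-style MLPs
        return self.down_proj(h * torch.relu(h))
```

Note that `x * relu(x)` and `relu(x) ** 2` are the same function elementwise, so either formulation can be used.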
Implementation Details
- Modular implementation inheriting from Llama components where applicable
- Custom ArceeMLP class implementing the ReLU² activation
- Full support for all standard transformers features:
- Flash Attention 2, SDPA, and other attention backends
- Gradient checkpointing
- Quantization support (including quantized caches)
- All standard model variants (CausalLM, SequenceClassification, QuestionAnswering, TokenClassification); a loading sketch follows this list
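As a usage sketch, an Arcee checkpoint should load through the standard auto classes once this PR lands. The repository id below is a placeholder, not a real checkpoint name, and the attention backend is just one of the supported options.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo id for illustration only
model_id = "arcee-ai/some-arcee-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,     # standard dtype selection
    attn_implementation="sdpa",     # or "flash_attention_2" where available
)

inputs = tokenizer("Hello, Arcee!", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```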
Testing
- Added comprehensive test suite following standard transformers test patterns
- Tests for all model variants and core functionality
- Specific test for ReLU² activation verification (sketched after this list)
- RoPE scaling tests including YARN support
- Verified model loading and forward/backward passes
- Confirmed compatibility with existing transformers infrastructure
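The ReLU² check could look roughly like the sketch below. It assumes the Arcee config selects the activation via `hidden_act="relu2"`, which `ACT2FN` maps to transformers' ReLU-squared activation; the actual test in the PR may be structured differently.

```python
import torch
from transformers.activations import ACT2FN


def test_relu2_matches_x_times_relu_x():
    # Assumption: the Arcee MLP uses ACT2FN["relu2"], i.e. relu(x) ** 2,
    # which equals x * relu(x) elementwise.
    act = ACT2FN["relu2"]
    x = torch.randn(4, 8)
    torch.testing.assert_close(act(x), x * torch.relu(x))
```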