Feature Request: Add AdaBelief Optimizer
Summary
I would like to propose adding the AdaBelief optimizer to MLX's optimizer suite. AdaBelief is a modern adaptive optimizer that has shown superior performance compared to Adam across various machine learning tasks.
Why AdaBelief?
Proven Performance
- Outperforms Adam on the image classification, language modeling, and machine translation benchmarks reported in the paper
- Shows faster convergence and better generalization in most of the reported experiments
- More stable training dynamics, especially in later stages of training
Research Impact & Adoption
- Published at NeurIPS 2020 by researchers from Yale University
- 400+ citations since publication, showing strong research community adoption
- Implementations are already available for PyTorch, TensorFlow, and JAX (author-maintained and community packages)
- Widely used by practitioners and researchers
Technical Advantages
- Same computational cost as Adam (no additional overhead)
- Same memory requirements as Adam
- Drop-in replacement for Adam in existing workflows
- Numerically stable with proper epsilon handling
MLX-Specific Benefits
- Well suited to Apple Silicon: the update has the same computational pattern as Adam, so the same optimizations apply
- Can leverage MLX's unified memory model efficiently
- Enhances MLX's appeal to the research community
- Aligns with MLX's goal of providing modern ML tools
References
- Original Paper: https://arxiv.org/abs/2010.07468
- Official Implementation: https://github.com/juntang-zhuang/Adabelief-Optimizer
- PyTorch implementation: the adabelief-pytorch package on PyPI (AdaBelief is not part of core torch.optim)
I'm excited to contribute to MLX and help expand its optimizer ecosystem. Looking forward to your feedback and guidance on how to proceed.
Hi everyone,
I’d love to work on adding the AdaBelief optimizer to mlx.optimizers as proposed here.
Planned implementation
- API similar to Adam/AdamW: learning_rate, betas, eps, weight_decay, and an optional bias_correction flag
- Core update rule from the NeurIPS 2020 paper: an EMA of (g - m)^2 for the variance term (see the sketch after this list)
- Same computational and memory cost as Adam
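For concreteness, here is a minimal sketch of a single update step; the function name, state-dict layout, and defaults are placeholders, and the real version would plug into the Optimizer base class the way Adam does:

```python
# Minimal sketch of one AdaBelief update step (placeholder names, not the
# final API). The only change relative to Adam is that the second-moment EMA
# tracks (g - m)^2 instead of g^2.
import mlx.core as mx

def adabelief_step(parameter, gradient, state, lr=1e-3, betas=(0.9, 0.999),
                   eps=1e-16, step=1, bias_correction=True):
    b1, b2 = betas
    m = b1 * state["m"] + (1 - b1) * gradient                       # EMA of gradients, as in Adam
    s = b2 * state["s"] + (1 - b2) * mx.square(gradient - m) + eps  # EMA of (g - m)^2, per the paper
    state["m"], state["s"] = m, s
    if bias_correction:
        m = m / (1 - b1**step)
        s = s / (1 - b2**step)
    return parameter - lr * m / (mx.sqrt(s) + eps)
```

If decoupled weight decay makes it into the first version, it would be an extra lr * weight_decay * parameter term subtracted alongside this update, AdamW-style.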
Questions before starting
- Should we keep eps=1e-8 (matching Adam) or use the paper's 1e-16 for numerical stability?
- The paper enables bias correction (bias_correction=True), while MLX's Adam currently defaults to False. Which behavior should I follow for consistency?
- Should decoupled weight decay (AdamW-style) be part of the first version? (The example call after this list illustrates these options.)
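To make the questions concrete, this is roughly how the constructor would be called; the class does not exist in MLX yet, and every default shown here is exactly what is up for discussion:

```python
# Hypothetical usage of the proposed optimizer (AdaBelief is not in
# mlx.optimizers yet; the defaults below are the open questions).
import mlx.optimizers as optim

optimizer = optim.AdaBelief(
    learning_rate=1e-3,
    betas=[0.9, 0.999],
    eps=1e-16,             # paper default; 1e-8 would match MLX's Adam
    weight_decay=0.0,      # decoupled (AdamW-style), if included in v1
    bias_correction=True,  # paper default; MLX's Adam defaults to False
)
```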
Planned additions
- Unit tests under python/tests/test_optimizers.py (a rough sketch follows this list)
- Docstring and a small usage example in the Optimizers docs
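A rough sketch of the kind of test I have in mind, assuming the optimizer lands as mlx.optimizers.AdaBelief with an Adam-like interface (class and test names are placeholders):

```python
# Rough sketch for python/tests/test_optimizers.py; assumes the proposed
# mlx.optimizers.AdaBelief exists and follows the usual update() API.
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class TestAdaBelief(unittest.TestCase):
    def test_loss_decreases(self):
        model = nn.Linear(4, 1)
        optimizer = optim.AdaBelief(learning_rate=1e-2)

        def loss_fn(model, x, y):
            return nn.losses.mse_loss(model(x), y, reduction="mean")

        loss_and_grad = nn.value_and_grad(model, loss_fn)
        x = mx.random.normal((32, 4))
        y = mx.random.normal((32, 1))

        initial_loss, grads = loss_and_grad(model, x, y)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        final_loss, _ = loss_and_grad(model, x, y)

        self.assertLess(final_loss.item(), initial_loss.item())
```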
If this scope looks good, I can start on the implementation right away.
Thanks for reviewing, and looking forward to your feedback!