Add MuonW optimizer: Muon with AdamW fallback for non-matrix parameters
This PR adds the MuonW optimizer to OLMo, implementing the Muon optimization algorithm with an AdamW fallback for non-matrix parameters.
Key features:
- Implements Muon's Newton-Schulz orthogonalization for matrix parameters (2D and higher); see the sketch after this list
- Falls back to AdamW for scalar/vector parameters and embeddings/heads
- Fully compatible with distributed training (FSDP)
- Includes comprehensive metric tracking for monitoring
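For reviewers, here is a minimal sketch of the quintic Newton-Schulz iteration the first bullet refers to. The coefficients follow the public Muon reference implementation; the function name and exact details are illustrative rather than copied from this PR:

```python
import torch

def newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Approximately orthogonalize a 2D update via a quintic Newton-Schulz
    # iteration. Coefficients (a, b, c) follow the public Muon reference
    # implementation; they are tuned for fast convergence of the singular
    # values toward 1, not for an exact polar decomposition.
    assert g.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g.bfloat16()
    x = x / (x.norm() + eps)  # Frobenius norm bounds the spectral norm, so the iteration converges
    transposed = g.size(0) > g.size(1)
    if transposed:
        x = x.T  # iterate on the short side so x @ x.T is the smaller Gram matrix
    for _ in range(steps):
        A = x @ x.T
        B = b * A + c * (A @ A)
        x = a * x + B @ x
    if transposed:
        x = x.T
    return x.to(g.dtype)
```

Each iteration applies the odd polynomial `a*x + b*x**3 + c*x**5` to the singular values of the update, pushing them toward 1 so the step behaves like the nearest semi-orthogonal matrix with the same singular vectors as the raw momentum.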
Implementation details:
- Based on the original Muon paper and reference implementation
- Adds distributed metric collection and reduction
- Handles conv filters by reshaping them to 2D matrices (see the routing sketch after this list)
- Supports selective weight updates and gradient clipping
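To make the fallback and conv-filter handling concrete, here is a hedged sketch of the shape-based routing; the helper name and string tags are hypothetical, not taken from the PR:

```python
import torch

def route_param(grad: torch.Tensor) -> tuple[str, torch.Tensor]:
    # Muon's orthogonalization is defined on 2D matrices, so tensors with
    # fewer dims take the AdamW path, and higher-dim tensors (conv filters)
    # are flattened to 2D first.
    if grad.ndim < 2:
        # biases, LayerNorm gains, and other scalars/vectors -> AdamW
        return "adamw", grad
    if grad.ndim > 2:
        # e.g. a (out_ch, in_ch, kH, kW) conv filter -> (out_ch, in_ch * kH * kW)
        grad = grad.reshape(grad.size(0), -1)
    return "muon", grad
```

Embedding and head weights are 2D but still fall back to AdamW, so that part of the routing would presumably be decided by param-group membership rather than by shape.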
Testing:
- Tested on single GPU/CPU with a comprehensive test suite
- Mock tests verify distributed code paths
- Convergence verified on regression tasks (a representative smoke test is sketched below)
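For reference, the kind of smoke test the last bullet refers to, written generically so it assumes nothing about MuonW beyond the standard `torch.optim` constructor convention:

```python
import torch

def check_convergence_on_regression(opt_cls, lr: float = 0.02, steps: int = 500):
    # Generic smoke test: any reasonable optimizer should drive a noiseless
    # linear regression loss down by orders of magnitude. `opt_cls` is assumed
    # to follow the standard torch.optim constructor convention.
    torch.manual_seed(0)
    X = torch.randn(256, 16)
    y = X @ torch.randn(16, 1)
    model = torch.nn.Linear(16, 1, bias=False)
    opt = opt_cls(model.parameters(), lr=lr)
    initial = torch.nn.functional.mse_loss(model(X), y).item()
    for _ in range(steps):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(X), y)
        loss.backward()
        opt.step()
    assert loss.item() < 1e-2 * initial, (initial, loss.item())
```

Running the same check with `torch.optim.AdamW` gives a convenient baseline to compare the MuonW run against.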
Happy to add config integration if there's interest. Tested locally; all core functionality is working.
Hi team, just wanted to gently follow up on this PR for the MuonW optimizer.
I know you're all very busy, so no rush at all. Please let me know if you have any questions, or if there are changes or additional tests I can provide to help move the review along.
Thanks for your time and for maintaining this great project!
Hi there, thanks for your contribution and interest! We apologize for the delay in responding to your PR - we are indeed at a busy time of year. We will take a look at this as soon as we can!