
Adding support for the Muon Optimizer

Goekdeniz-Guelmez opened this issue 9 months ago • 18 comments

Proposed changes

First contribution to the MLX repo. This adds the Muon optimizer to MLX's optimizer suite. Muon (MomentUm Orthogonalized by Newton-Schulz) combines momentum-based SGD with orthogonalization of the parameter updates via Newton-Schulz iterations, and has shown promising results for training neural networks, particularly convolutional and transformer architectures. The implementation follows the approach described in https://kellerjordan.github.io/posts/muon/ , adapted to MLX. The optimizer performs a standard SGD-momentum update, followed by an orthogonalization step that replaces each 2D parameter's update with the nearest orthogonal matrix via an efficient Newton-Schulz iteration (a rough sketch of this step follows the feature list below). Key features of this implementation:

  • Support for standard optimizer features (learning rate, momentum, weight decay, Nesterov)
  • Efficient Newton-Schulz orthogonalization that works with bfloat16
  • Special handling for parameters of different dimensions
  • Appropriate scaling for non-square matrices
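For context, here is a minimal sketch of the Newton-Schulz orthogonalization step described in the linked post, written against MLX. The function name, coefficients, and default iteration count follow that post and are illustrative only; they are not necessarily identical to what this PR ships:

import mlx.core as mx

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    # Quintic Newton-Schulz iteration that pushes the singular values of G toward 1,
    # approximating the nearest orthogonal matrix. Coefficients follow the Muon post,
    # and the iteration runs in bfloat16 for speed.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.astype(mx.bfloat16)
    X = X / (mx.linalg.norm(X) + eps)  # bound the top singular value near 1
    transposed = G.shape[0] > G.shape[1]
    if transposed:
        X = X.T  # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X

In Muon this is applied to the momentum buffer of each 2D weight, and the result, scaled appropriately for non-square matrices, becomes the update.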

Checklist

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

Goekdeniz-Guelmez avatar Feb 28 '25 22:02 Goekdeniz-Guelmez

Used a basic 2-layer MLP with a dummy dataset:

Training with Muon optimizer...
Epoch 1/3, Batch 0/78, Loss: 2.3055
Epoch 1/3, Batch 10/78, Loss: 2.3316
Epoch 1/3, Batch 20/78, Loss: 2.3054
Epoch 1/3, Batch 30/78, Loss: 2.3088
Epoch 1/3, Batch 40/78, Loss: 2.2991
Epoch 1/3, Batch 50/78, Loss: 2.2983
Epoch 1/3, Batch 60/78, Loss: 2.3218
Epoch 1/3, Batch 70/78, Loss: 2.2940
Epoch 1/3 completed in 0.58s - Train Loss: 2.3070, Val Loss: 2.3065, Val Acc: 0.0917
Epoch 2/3, Batch 0/78, Loss: 2.2920
Epoch 2/3, Batch 10/78, Loss: 2.2957
Epoch 2/3, Batch 20/78, Loss: 2.2779
Epoch 2/3, Batch 30/78, Loss: 2.2840
Epoch 2/3, Batch 40/78, Loss: 2.2765
Epoch 2/3, Batch 50/78, Loss: 2.2732
Epoch 2/3, Batch 60/78, Loss: 2.2992
Epoch 2/3, Batch 70/78, Loss: 2.2644
Epoch 2/3 completed in 0.06s - Train Loss: 2.2824, Val Loss: 2.3078, Val Acc: 0.0990
Epoch 3/3, Batch 0/78, Loss: 2.2648
Epoch 3/3, Batch 10/78, Loss: 2.2709
Epoch 3/3, Batch 20/78, Loss: 2.2467
Epoch 3/3, Batch 30/78, Loss: 2.2562
Epoch 3/3, Batch 40/78, Loss: 2.2492
Epoch 3/3, Batch 50/78, Loss: 2.2410
Epoch 3/3, Batch 60/78, Loss: 2.2740
Epoch 3/3, Batch 70/78, Loss: 2.2283
Epoch 3/3 completed in 0.06s - Train Loss: 2.2541, Val Loss: 2.3098, Val Acc: 0.0969

Training with standard SGD optimizer for comparison...
Epoch 1/3, Batch 0/78, Loss: 2.3028
Epoch 1/3, Batch 10/78, Loss: 2.3079
Epoch 1/3, Batch 20/78, Loss: 2.3219
Epoch 1/3, Batch 30/78, Loss: 2.3094
Epoch 1/3, Batch 40/78, Loss: 2.3017
Epoch 1/3, Batch 50/78, Loss: 2.3161
Epoch 1/3, Batch 60/78, Loss: 2.3095
Epoch 1/3, Batch 70/78, Loss: 2.2873
Epoch 1/3 completed in 0.03s - Train Loss: 2.3081, Val Loss: 2.3074, Val Acc: 0.0969
Epoch 2/3, Batch 0/78, Loss: 2.2914
Epoch 2/3, Batch 10/78, Loss: 2.2927
Epoch 2/3, Batch 20/78, Loss: 2.3017
Epoch 2/3, Batch 30/78, Loss: 2.2921
Epoch 2/3, Batch 40/78, Loss: 2.2866
Epoch 2/3, Batch 50/78, Loss: 2.3123
Epoch 2/3, Batch 60/78, Loss: 2.3063
Epoch 2/3, Batch 70/78, Loss: 2.2799
Epoch 2/3 completed in 0.03s - Train Loss: 2.2974, Val Loss: 2.3079, Val Acc: 0.0896
Epoch 3/3, Batch 0/78, Loss: 2.2833
Epoch 3/3, Batch 10/78, Loss: 2.2801
Epoch 3/3, Batch 20/78, Loss: 2.2918
Epoch 3/3, Batch 30/78, Loss: 2.2843
Epoch 3/3, Batch 40/78, Loss: 2.2869
Epoch 3/3, Batch 50/78, Loss: 2.2981
Epoch 3/3, Batch 60/78, Loss: 2.2954
Epoch 3/3, Batch 70/78, Loss: 2.2703
Epoch 3/3 completed in 0.03s - Train Loss: 2.2884, Val Loss: 2.3082, Val Acc: 0.1177

More training runs will come!

Goekdeniz-Guelmez avatar Feb 28 '25 22:02 Goekdeniz-Guelmez

LLM SFT Finetuning

python -m mlx_lm.lora \
--model Qwen/Qwen2.5-1.5B-Instruct \
--train \
--data /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/MLX/data_smoll \ <- samantha
--fine-tune-type dora \
--num-layers 4 \
--batch-size 1 \
--iters 100 \
--val-batches 1 \
--steps-per-report 1 \
--steps-per-eval 50 \
--adapter-path /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/MLX/test_Muon \
--save-every 10 \
--max-seq-length 4096 \
--grad-checkpoint

Muon

Iter 1: Val loss 1.638, Val took 1.710s
Iter 1: Train loss 1.466, Learning Rate 1.000e-05, It/sec 0.958, Tokens/sec 619.948, Trained Tokens 647, Peak mem 5.860 GB
Iter 2: Train loss 1.189, Learning Rate 1.000e-05, It/sec 0.173, Tokens/sec 564.635, Trained Tokens 3902, Peak mem 12.744 GB
Iter 3: Train loss 1.399, Learning Rate 1.000e-05, It/sec 0.204, Tokens/sec 356.355, Trained Tokens 5652, Peak mem 12.744 GB
Iter 4: Train loss 2.229, Learning Rate 1.000e-05, It/sec 0.933, Tokens/sec 323.692, Trained Tokens 5999, Peak mem 12.744 GB
Iter 5: Train loss 1.512, Learning Rate 1.000e-05, It/sec 0.340, Tokens/sec 621.293, Trained Tokens 7825, Peak mem 12.744 GB
Iter 6: Train loss 1.446, Learning Rate 1.000e-05, It/sec 0.582, Tokens/sec 617.335, Trained Tokens 8886, Peak mem 12.744 GB
Iter 7: Train loss 1.597, Learning Rate 1.000e-05, It/sec 0.823, Tokens/sec 627.471, Trained Tokens 9648, Peak mem 12.744 GB
Iter 8: Train loss 1.793, Learning Rate 1.000e-05, It/sec 0.558, Tokens/sec 523.831, Trained Tokens 10587, Peak mem 12.744 GB
Iter 9: Train loss 1.763, Learning Rate 1.000e-05, It/sec 0.646, Tokens/sec 658.091, Trained Tokens 11605, Peak mem 12.744 GB
Iter 10: Train loss 1.022, Learning Rate 1.000e-05, It/sec 0.367, Tokens/sec 520.598, Trained Tokens 13025, Peak mem 12.744 GB
Iter 10: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000010_adapters.safetensors.
Iter 11: Train loss 1.156, Learning Rate 1.000e-05, It/sec 0.469, Tokens/sec 619.401, Trained Tokens 14346, Peak mem 12.744 GB
Iter 12: Train loss 1.381, Learning Rate 1.000e-05, It/sec 0.494, Tokens/sec 614.444, Trained Tokens 15590, Peak mem 12.744 GB
Iter 13: Train loss 1.725, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 578.744, Trained Tokens 17956, Peak mem 12.744 GB
Iter 14: Train loss 1.447, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 556.406, Trained Tokens 20224, Peak mem 12.744 GB
Iter 15: Train loss 1.443, Learning Rate 1.000e-05, It/sec 0.295, Tokens/sec 581.241, Trained Tokens 22194, Peak mem 12.744 GB
Iter 16: Train loss 1.397, Learning Rate 1.000e-05, It/sec 0.382, Tokens/sec 610.621, Trained Tokens 23793, Peak mem 12.744 GB
Iter 17: Train loss 1.550, Learning Rate 1.000e-05, It/sec 0.551, Tokens/sec 602.464, Trained Tokens 24886, Peak mem 12.744 GB
Iter 18: Train loss 0.884, Learning Rate 1.000e-05, It/sec 0.574, Tokens/sec 611.241, Trained Tokens 25951, Peak mem 12.744 GB
Iter 19: Train loss 1.424, Learning Rate 1.000e-05, It/sec 0.254, Tokens/sec 341.668, Trained Tokens 27296, Peak mem 12.744 GB
Iter 20: Train loss 1.713, Learning Rate 1.000e-05, It/sec 0.422, Tokens/sec 606.567, Trained Tokens 28735, Peak mem 12.744 GB
Iter 20: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000020_adapters.safetensors.
Iter 21: Train loss 1.322, Learning Rate 1.000e-05, It/sec 0.844, Tokens/sec 615.596, Trained Tokens 29464, Peak mem 12.744 GB
Iter 22: Train loss 1.685, Learning Rate 1.000e-05, It/sec 0.490, Tokens/sec 597.661, Trained Tokens 30683, Peak mem 12.744 GB
Iter 23: Train loss 1.298, Learning Rate 1.000e-05, It/sec 0.380, Tokens/sec 572.848, Trained Tokens 32190, Peak mem 12.744 GB
Iter 24: Train loss 1.707, Learning Rate 1.000e-05, It/sec 0.769, Tokens/sec 598.655, Trained Tokens 32968, Peak mem 12.744 GB
Iter 25: Train loss 1.942, Learning Rate 1.000e-05, It/sec 1.001, Tokens/sec 618.561, Trained Tokens 33586, Peak mem 12.744 GB
Iter 26: Train loss 1.394, Learning Rate 1.000e-05, It/sec 0.232, Tokens/sec 438.454, Trained Tokens 35475, Peak mem 12.744 GB
Iter 27: Train loss 0.959, Learning Rate 1.000e-05, It/sec 0.698, Tokens/sec 606.886, Trained Tokens 36344, Peak mem 12.744 GB
Iter 28: Train loss 1.813, Learning Rate 1.000e-05, It/sec 1.252, Tokens/sec 629.799, Trained Tokens 36847, Peak mem 12.744 GB
Iter 29: Train loss 1.326, Learning Rate 1.000e-05, It/sec 0.453, Tokens/sec 582.481, Trained Tokens 38134, Peak mem 12.744 GB
Iter 30: Train loss 1.428, Learning Rate 1.000e-05, It/sec 0.336, Tokens/sec 585.171, Trained Tokens 39877, Peak mem 12.744 GB
Iter 30: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000030_adapters.safetensors.
Iter 31: Train loss 1.194, Learning Rate 1.000e-05, It/sec 0.268, Tokens/sec 576.286, Trained Tokens 42028, Peak mem 12.744 GB
Iter 32: Train loss 1.635, Learning Rate 1.000e-05, It/sec 0.487, Tokens/sec 611.771, Trained Tokens 43283, Peak mem 12.744 GB
Iter 33: Train loss 1.634, Learning Rate 1.000e-05, It/sec 0.488, Tokens/sec 610.358, Trained Tokens 44534, Peak mem 12.744 GB
Iter 34: Train loss 1.293, Learning Rate 1.000e-05, It/sec 0.841, Tokens/sec 606.378, Trained Tokens 45255, Peak mem 12.744 GB
Iter 35: Train loss 1.499, Learning Rate 1.000e-05, It/sec 0.458, Tokens/sec 611.155, Trained Tokens 46588, Peak mem 12.744 GB
Iter 36: Train loss 1.660, Learning Rate 1.000e-05, It/sec 0.776, Tokens/sec 607.610, Trained Tokens 47371, Peak mem 12.744 GB
Iter 37: Train loss 1.599, Learning Rate 1.000e-05, It/sec 0.780, Tokens/sec 600.431, Trained Tokens 48141, Peak mem 12.744 GB
Iter 38: Train loss 1.995, Learning Rate 1.000e-05, It/sec 0.657, Tokens/sec 621.528, Trained Tokens 49087, Peak mem 12.744 GB
Iter 39: Train loss 1.799, Learning Rate 1.000e-05, It/sec 0.617, Tokens/sec 608.044, Trained Tokens 50072, Peak mem 12.744 GB
Iter 40: Train loss 1.822, Learning Rate 1.000e-05, It/sec 0.583, Tokens/sec 603.529, Trained Tokens 51108, Peak mem 12.744 GB
Iter 40: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000040_adapters.safetensors.
Iter 41: Train loss 1.490, Learning Rate 1.000e-05, It/sec 1.013, Tokens/sec 602.438, Trained Tokens 51703, Peak mem 12.744 GB
Iter 42: Train loss 1.045, Learning Rate 1.000e-05, It/sec 0.542, Tokens/sec 608.561, Trained Tokens 52826, Peak mem 12.744 GB
Iter 43: Train loss 1.379, Learning Rate 1.000e-05, It/sec 0.416, Tokens/sec 597.864, Trained Tokens 54264, Peak mem 12.744 GB
Iter 44: Train loss 1.652, Learning Rate 1.000e-05, It/sec 0.625, Tokens/sec 601.502, Trained Tokens 55227, Peak mem 12.744 GB
Iter 45: Train loss 1.688, Learning Rate 1.000e-05, It/sec 0.910, Tokens/sec 630.532, Trained Tokens 55920, Peak mem 12.744 GB
Iter 46: Train loss 1.800, Learning Rate 1.000e-05, It/sec 0.769, Tokens/sec 620.766, Trained Tokens 56727, Peak mem 12.744 GB
Iter 47: Train loss 1.776, Learning Rate 1.000e-05, It/sec 0.699, Tokens/sec 620.070, Trained Tokens 57614, Peak mem 12.744 GB
Iter 48: Train loss 1.485, Learning Rate 1.000e-05, It/sec 0.656, Tokens/sec 627.084, Trained Tokens 58570, Peak mem 12.744 GB
Iter 49: Train loss 1.485, Learning Rate 1.000e-05, It/sec 0.253, Tokens/sec 570.277, Trained Tokens 60827, Peak mem 12.744 GB
Iter 50: Val loss 1.750, Val took 1.551s
Iter 50: Train loss 1.714, Learning Rate 1.000e-05, It/sec 0.512, Tokens/sec 617.394, Trained Tokens 62032, Peak mem 12.744 GB
Iter 50: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000050_adapters.safetensors.
Iter 51: Train loss 1.254, Learning Rate 1.000e-05, It/sec 0.657, Tokens/sec 619.473, Trained Tokens 62975, Peak mem 12.744 GB
Iter 52: Train loss 1.144, Learning Rate 1.000e-05, It/sec 0.438, Tokens/sec 600.079, Trained Tokens 64344, Peak mem 12.744 GB
Iter 53: Train loss 1.054, Learning Rate 1.000e-05, It/sec 0.624, Tokens/sec 599.977, Trained Tokens 65306, Peak mem 12.744 GB
Iter 54: Train loss 1.468, Learning Rate 1.000e-05, It/sec 0.465, Tokens/sec 600.392, Trained Tokens 66598, Peak mem 12.744 GB
Iter 55: Train loss 1.737, Learning Rate 1.000e-05, It/sec 0.706, Tokens/sec 623.741, Trained Tokens 67482, Peak mem 12.744 GB
Iter 56: Train loss 1.339, Learning Rate 1.000e-05, It/sec 0.542, Tokens/sec 616.532, Trained Tokens 68619, Peak mem 12.744 GB
Iter 57: Train loss 1.531, Learning Rate 1.000e-05, It/sec 0.295, Tokens/sec 581.258, Trained Tokens 70588, Peak mem 12.744 GB
Iter 58: Train loss 1.428, Learning Rate 1.000e-05, It/sec 0.462, Tokens/sec 592.028, Trained Tokens 71869, Peak mem 12.744 GB
Iter 59: Train loss 1.068, Learning Rate 1.000e-05, It/sec 0.396, Tokens/sec 593.931, Trained Tokens 73369, Peak mem 12.744 GB
Iter 60: Train loss 1.690, Learning Rate 1.000e-05, It/sec 1.116, Tokens/sec 617.204, Trained Tokens 73922, Peak mem 12.744 GB
Iter 60: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000060_adapters.safetensors.
Iter 61: Train loss 1.314, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 568.104, Trained Tokens 76239, Peak mem 12.744 GB
Iter 62: Train loss 1.633, Learning Rate 1.000e-05, It/sec 0.415, Tokens/sec 362.680, Trained Tokens 77112, Peak mem 12.744 GB
Iter 63: Train loss 1.392, Learning Rate 1.000e-05, It/sec 0.322, Tokens/sec 581.527, Trained Tokens 78919, Peak mem 12.744 GB
Iter 64: Train loss 2.108, Learning Rate 1.000e-05, It/sec 1.422, Tokens/sec 605.856, Trained Tokens 79345, Peak mem 12.744 GB
Iter 65: Train loss 1.301, Learning Rate 1.000e-05, It/sec 0.381, Tokens/sec 594.039, Trained Tokens 80904, Peak mem 12.744 GB
Iter 66: Train loss 1.689, Learning Rate 1.000e-05, It/sec 1.009, Tokens/sec 592.073, Trained Tokens 81491, Peak mem 12.744 GB
Iter 67: Train loss 1.520, Learning Rate 1.000e-05, It/sec 0.394, Tokens/sec 596.168, Trained Tokens 83006, Peak mem 12.744 GB
Iter 68: Train loss 1.473, Learning Rate 1.000e-05, It/sec 0.320, Tokens/sec 581.864, Trained Tokens 84824, Peak mem 12.744 GB
Iter 69: Train loss 1.679, Learning Rate 1.000e-05, It/sec 0.461, Tokens/sec 583.405, Trained Tokens 86090, Peak mem 12.744 GB
Iter 70: Train loss 1.591, Learning Rate 1.000e-05, It/sec 0.608, Tokens/sec 597.838, Trained Tokens 87073, Peak mem 12.744 GB
Iter 70: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000070_adapters.safetensors.
Iter 71: Train loss 1.537, Learning Rate 1.000e-05, It/sec 0.300, Tokens/sec 577.528, Trained Tokens 88996, Peak mem 12.744 GB
Iter 72: Train loss 1.283, Learning Rate 1.000e-05, It/sec 0.383, Tokens/sec 599.126, Trained Tokens 90562, Peak mem 12.744 GB
Iter 73: Train loss 2.318, Learning Rate 1.000e-05, It/sec 1.264, Tokens/sec 578.981, Trained Tokens 91020, Peak mem 12.744 GB
Iter 74: Train loss 1.272, Learning Rate 1.000e-05, It/sec 0.375, Tokens/sec 426.068, Trained Tokens 92157, Peak mem 12.744 GB
Iter 75: Train loss 1.732, Learning Rate 1.000e-05, It/sec 0.362, Tokens/sec 588.029, Trained Tokens 93783, Peak mem 12.744 GB
Iter 76: Train loss 1.517, Learning Rate 1.000e-05, It/sec 0.303, Tokens/sec 540.497, Trained Tokens 95568, Peak mem 12.744 GB
Iter 77: Train loss 1.440, Learning Rate 1.000e-05, It/sec 0.745, Tokens/sec 614.929, Trained Tokens 96393, Peak mem 12.744 GB
Iter 78: Train loss 1.558, Learning Rate 1.000e-05, It/sec 0.600, Tokens/sec 613.459, Trained Tokens 97416, Peak mem 12.744 GB
Iter 79: Train loss 1.480, Learning Rate 1.000e-05, It/sec 0.354, Tokens/sec 490.311, Trained Tokens 98801, Peak mem 12.744 GB
Iter 80: Train loss 1.417, Learning Rate 1.000e-05, It/sec 0.825, Tokens/sec 599.015, Trained Tokens 99527, Peak mem 12.744 GB
Iter 80: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000080_adapters.safetensors.
Iter 81: Train loss 1.021, Learning Rate 1.000e-05, It/sec 0.367, Tokens/sec 352.983, Trained Tokens 100488, Peak mem 12.744 GB
Iter 82: Train loss 1.551, Learning Rate 1.000e-05, It/sec 0.293, Tokens/sec 576.772, Trained Tokens 102455, Peak mem 12.744 GB
Iter 83: Train loss 1.587, Learning Rate 1.000e-05, It/sec 0.505, Tokens/sec 607.992, Trained Tokens 103659, Peak mem 12.744 GB
Iter 84: Train loss 1.675, Learning Rate 1.000e-05, It/sec 0.917, Tokens/sec 589.006, Trained Tokens 104301, Peak mem 12.744 GB
Iter 85: Train loss 1.498, Learning Rate 1.000e-05, It/sec 0.612, Tokens/sec 602.992, Trained Tokens 105287, Peak mem 12.744 GB
Iter 86: Train loss 1.832, Learning Rate 1.000e-05, It/sec 0.646, Tokens/sec 611.999, Trained Tokens 106235, Peak mem 12.744 GB
Iter 87: Train loss 1.497, Learning Rate 1.000e-05, It/sec 0.457, Tokens/sec 610.228, Trained Tokens 107570, Peak mem 12.744 GB
Iter 88: Train loss 1.059, Learning Rate 1.000e-05, It/sec 0.310, Tokens/sec 437.320, Trained Tokens 108980, Peak mem 12.744 GB
Iter 89: Train loss 1.317, Learning Rate 1.000e-05, It/sec 0.565, Tokens/sec 608.608, Trained Tokens 110057, Peak mem 12.744 GB
Iter 90: Train loss 1.514, Learning Rate 1.000e-05, It/sec 0.255, Tokens/sec 437.049, Trained Tokens 111774, Peak mem 12.744 GB
Iter 90: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000090_adapters.safetensors.
Iter 91: Train loss 1.480, Learning Rate 1.000e-05, It/sec 0.547, Tokens/sec 601.086, Trained Tokens 112873, Peak mem 12.744 GB
Iter 92: Train loss 1.635, Learning Rate 1.000e-05, It/sec 0.456, Tokens/sec 441.178, Trained Tokens 113840, Peak mem 12.744 GB
Iter 93: Train loss 1.270, Learning Rate 1.000e-05, It/sec 0.567, Tokens/sec 605.699, Trained Tokens 114909, Peak mem 12.744 GB
Iter 94: Train loss 1.615, Learning Rate 1.000e-05, It/sec 0.289, Tokens/sec 487.284, Trained Tokens 116595, Peak mem 12.744 GB
Iter 95: Train loss 1.383, Learning Rate 1.000e-05, It/sec 0.254, Tokens/sec 397.945, Trained Tokens 118162, Peak mem 12.744 GB
Iter 96: Train loss 1.099, Learning Rate 1.000e-05, It/sec 0.545, Tokens/sec 620.016, Trained Tokens 119299, Peak mem 12.744 GB
Iter 97: Train loss 1.660, Learning Rate 1.000e-05, It/sec 0.465, Tokens/sec 597.845, Trained Tokens 120585, Peak mem 12.744 GB
Iter 98: Train loss 1.412, Learning Rate 1.000e-05, It/sec 0.402, Tokens/sec 593.111, Trained Tokens 122061, Peak mem 12.744 GB
Iter 99: Train loss 1.752, Learning Rate 1.000e-05, It/sec 0.441, Tokens/sec 596.284, Trained Tokens 123412, Peak mem 12.744 GB
Iter 100: Val loss 1.590, Val took 1.321s

Adam

Iter 1: Val loss 1.638, Val took 1.709s
Iter 1: Train loss 1.466, Learning Rate 1.000e-05, It/sec 0.972, Tokens/sec 629.107, Trained Tokens 647, Peak mem 5.860 GB
Iter 2: Train loss 1.189, Learning Rate 1.000e-05, It/sec 0.177, Tokens/sec 575.080, Trained Tokens 3902, Peak mem 12.744 GB
Iter 3: Train loss 1.398, Learning Rate 1.000e-05, It/sec 0.356, Tokens/sec 623.304, Trained Tokens 5652, Peak mem 12.744 GB
Iter 4: Train loss 2.225, Learning Rate 1.000e-05, It/sec 1.811, Tokens/sec 628.255, Trained Tokens 5999, Peak mem 12.744 GB
Iter 5: Train loss 1.511, Learning Rate 1.000e-05, It/sec 0.319, Tokens/sec 582.413, Trained Tokens 7825, Peak mem 12.744 GB
Iter 6: Train loss 1.445, Learning Rate 1.000e-05, It/sec 0.613, Tokens/sec 650.911, Trained Tokens 8886, Peak mem 12.744 GB
Iter 7: Train loss 1.593, Learning Rate 1.000e-05, It/sec 0.881, Tokens/sec 671.453, Trained Tokens 9648, Peak mem 12.744 GB
Iter 8: Train loss 1.786, Learning Rate 1.000e-05, It/sec 0.601, Tokens/sec 564.607, Trained Tokens 10587, Peak mem 12.744 GB
Iter 9: Train loss 1.756, Learning Rate 1.000e-05, It/sec 0.645, Tokens/sec 656.109, Trained Tokens 11605, Peak mem 12.744 GB
Iter 10: Train loss 1.019, Learning Rate 1.000e-05, It/sec 0.445, Tokens/sec 632.010, Trained Tokens 13025, Peak mem 12.744 GB
Iter 10: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000010_adapters.safetensors.
Iter 11: Train loss 1.154, Learning Rate 1.000e-05, It/sec 0.481, Tokens/sec 635.520, Trained Tokens 14346, Peak mem 12.744 GB
Iter 12: Train loss 1.377, Learning Rate 1.000e-05, It/sec 0.444, Tokens/sec 552.467, Trained Tokens 15590, Peak mem 12.744 GB
Iter 13: Train loss 1.722, Learning Rate 1.000e-05, It/sec 0.256, Tokens/sec 604.657, Trained Tokens 17956, Peak mem 12.744 GB
Iter 14: Train loss 1.442, Learning Rate 1.000e-05, It/sec 0.266, Tokens/sec 602.850, Trained Tokens 20224, Peak mem 12.744 GB
Iter 15: Train loss 1.437, Learning Rate 1.000e-05, It/sec 0.312, Tokens/sec 614.792, Trained Tokens 22194, Peak mem 12.744 GB
Iter 16: Train loss 1.387, Learning Rate 1.000e-05, It/sec 0.396, Tokens/sec 633.014, Trained Tokens 23793, Peak mem 12.744 GB
Iter 17: Train loss 1.538, Learning Rate 1.000e-05, It/sec 0.573, Tokens/sec 625.796, Trained Tokens 24886, Peak mem 12.744 GB
Iter 18: Train loss 0.874, Learning Rate 1.000e-05, It/sec 0.596, Tokens/sec 635.225, Trained Tokens 25951, Peak mem 12.744 GB
Iter 19: Train loss 1.410, Learning Rate 1.000e-05, It/sec 0.443, Tokens/sec 596.280, Trained Tokens 27296, Peak mem 12.744 GB
Iter 20: Train loss 1.700, Learning Rate 1.000e-05, It/sec 0.423, Tokens/sec 609.029, Trained Tokens 28735, Peak mem 12.744 GB
Iter 20: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000020_adapters.safetensors.
Iter 21: Train loss 1.297, Learning Rate 1.000e-05, It/sec 0.853, Tokens/sec 622.047, Trained Tokens 29464, Peak mem 12.744 GB
Iter 22: Train loss 1.664, Learning Rate 1.000e-05, It/sec 0.492, Tokens/sec 599.321, Trained Tokens 30683, Peak mem 12.744 GB
Iter 23: Train loss 1.282, Learning Rate 1.000e-05, It/sec 0.374, Tokens/sec 563.105, Trained Tokens 32190, Peak mem 12.744 GB
Iter 24: Train loss 1.677, Learning Rate 1.000e-05, It/sec 0.788, Tokens/sec 613.112, Trained Tokens 32968, Peak mem 12.744 GB
Iter 25: Train loss 1.898, Learning Rate 1.000e-05, It/sec 1.017, Tokens/sec 628.653, Trained Tokens 33586, Peak mem 12.744 GB
Iter 26: Train loss 1.379, Learning Rate 1.000e-05, It/sec 0.311, Tokens/sec 587.401, Trained Tokens 35475, Peak mem 12.744 GB
Iter 27: Train loss 0.933, Learning Rate 1.000e-05, It/sec 0.712, Tokens/sec 619.114, Trained Tokens 36344, Peak mem 12.744 GB
Iter 28: Train loss 1.784, Learning Rate 1.000e-05, It/sec 1.268, Tokens/sec 637.579, Trained Tokens 36847, Peak mem 12.744 GB
Iter 29: Train loss 1.300, Learning Rate 1.000e-05, It/sec 0.370, Tokens/sec 475.946, Trained Tokens 38134, Peak mem 12.744 GB
Iter 30: Train loss 1.412, Learning Rate 1.000e-05, It/sec 0.338, Tokens/sec 588.921, Trained Tokens 39877, Peak mem 12.744 GB
Iter 30: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000030_adapters.safetensors.
Iter 31: Train loss 1.183, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 528.035, Trained Tokens 42028, Peak mem 12.744 GB
Iter 32: Train loss 1.611, Learning Rate 1.000e-05, It/sec 0.370, Tokens/sec 464.962, Trained Tokens 43283, Peak mem 12.744 GB
Iter 33: Train loss 1.605, Learning Rate 1.000e-05, It/sec 0.491, Tokens/sec 614.706, Trained Tokens 44534, Peak mem 12.744 GB
Iter 34: Train loss 1.249, Learning Rate 1.000e-05, It/sec 0.845, Tokens/sec 609.598, Trained Tokens 45255, Peak mem 12.744 GB
Iter 35: Train loss 1.471, Learning Rate 1.000e-05, It/sec 0.458, Tokens/sec 610.400, Trained Tokens 46588, Peak mem 12.744 GB
Iter 36: Train loss 1.621, Learning Rate 1.000e-05, It/sec 0.742, Tokens/sec 580.700, Trained Tokens 47371, Peak mem 12.744 GB
Iter 37: Train loss 1.530, Learning Rate 1.000e-05, It/sec 0.684, Tokens/sec 526.483, Trained Tokens 48141, Peak mem 12.744 GB
Iter 38: Train loss 1.958, Learning Rate 1.000e-05, It/sec 0.659, Tokens/sec 623.280, Trained Tokens 49087, Peak mem 12.744 GB
Iter 39: Train loss 1.742, Learning Rate 1.000e-05, It/sec 0.519, Tokens/sec 511.139, Trained Tokens 50072, Peak mem 12.744 GB
Iter 40: Train loss 1.761, Learning Rate 1.000e-05, It/sec 0.489, Tokens/sec 506.599, Trained Tokens 51108, Peak mem 12.744 GB
Iter 40: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000040_adapters.safetensors.
Iter 41: Train loss 1.426, Learning Rate 1.000e-05, It/sec 0.716, Tokens/sec 425.871, Trained Tokens 51703, Peak mem 12.744 GB
Iter 42: Train loss 1.010, Learning Rate 1.000e-05, It/sec 0.415, Tokens/sec 466.172, Trained Tokens 52826, Peak mem 12.744 GB
Iter 43: Train loss 1.331, Learning Rate 1.000e-05, It/sec 0.393, Tokens/sec 564.757, Trained Tokens 54264, Peak mem 12.744 GB
Iter 44: Train loss 1.583, Learning Rate 1.000e-05, It/sec 0.621, Tokens/sec 597.810, Trained Tokens 55227, Peak mem 12.744 GB
Iter 45: Train loss 1.569, Learning Rate 1.000e-05, It/sec 0.358, Tokens/sec 248.046, Trained Tokens 55920, Peak mem 12.744 GB
Iter 46: Train loss 1.694, Learning Rate 1.000e-05, It/sec 0.694, Tokens/sec 559.991, Trained Tokens 56727, Peak mem 12.744 GB
Iter 47: Train loss 1.692, Learning Rate 1.000e-05, It/sec 0.636, Tokens/sec 564.235, Trained Tokens 57614, Peak mem 12.744 GB
Iter 48: Train loss 1.399, Learning Rate 1.000e-05, It/sec 0.646, Tokens/sec 617.509, Trained Tokens 58570, Peak mem 12.744 GB
Iter 49: Train loss 1.451, Learning Rate 1.000e-05, It/sec 0.242, Tokens/sec 547.274, Trained Tokens 60827, Peak mem 12.744 GB
Iter 50: Val loss 1.704, Val took 2.831s
Iter 50: Train loss 1.662, Learning Rate 1.000e-05, It/sec 0.505, Tokens/sec 608.072, Trained Tokens 62032, Peak mem 12.744 GB
Iter 50: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000050_adapters.safetensors.
Iter 51: Train loss 1.195, Learning Rate 1.000e-05, It/sec 0.655, Tokens/sec 617.362, Trained Tokens 62975, Peak mem 12.744 GB
Iter 52: Train loss 1.104, Learning Rate 1.000e-05, It/sec 0.438, Tokens/sec 599.338, Trained Tokens 64344, Peak mem 12.744 GB
Iter 53: Train loss 0.998, Learning Rate 1.000e-05, It/sec 0.603, Tokens/sec 580.070, Trained Tokens 65306, Peak mem 12.744 GB
Iter 54: Train loss 1.410, Learning Rate 1.000e-05, It/sec 0.296, Tokens/sec 382.980, Trained Tokens 66598, Peak mem 12.744 GB
Iter 55: Train loss 1.631, Learning Rate 1.000e-05, It/sec 0.673, Tokens/sec 594.911, Trained Tokens 67482, Peak mem 12.744 GB
Iter 56: Train loss 1.277, Learning Rate 1.000e-05, It/sec 0.469, Tokens/sec 532.733, Trained Tokens 68619, Peak mem 12.744 GB
Iter 57: Train loss 1.476, Learning Rate 1.000e-05, It/sec 0.241, Tokens/sec 473.725, Trained Tokens 70588, Peak mem 12.744 GB
Iter 58: Train loss 1.381, Learning Rate 1.000e-05, It/sec 0.458, Tokens/sec 586.469, Trained Tokens 71869, Peak mem 12.744 GB
Iter 59: Train loss 1.023, Learning Rate 1.000e-05, It/sec 0.396, Tokens/sec 594.025, Trained Tokens 73369, Peak mem 12.744 GB
Iter 60: Train loss 1.503, Learning Rate 1.000e-05, It/sec 1.126, Tokens/sec 622.941, Trained Tokens 73922, Peak mem 12.744 GB
Iter 60: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000060_adapters.safetensors.
Iter 61: Train loss 1.274, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 566.824, Trained Tokens 76239, Peak mem 12.744 GB
Iter 62: Train loss 1.542, Learning Rate 1.000e-05, It/sec 0.704, Tokens/sec 614.720, Trained Tokens 77112, Peak mem 12.744 GB
Iter 63: Train loss 1.327, Learning Rate 1.000e-05, It/sec 0.318, Tokens/sec 574.726, Trained Tokens 78919, Peak mem 12.744 GB
Iter 64: Train loss 1.936, Learning Rate 1.000e-05, It/sec 1.453, Tokens/sec 618.989, Trained Tokens 79345, Peak mem 12.744 GB
Iter 65: Train loss 1.261, Learning Rate 1.000e-05, It/sec 0.381, Tokens/sec 593.783, Trained Tokens 80904, Peak mem 12.744 GB
Iter 66: Train loss 1.509, Learning Rate 1.000e-05, It/sec 1.016, Tokens/sec 596.523, Trained Tokens 81491, Peak mem 12.744 GB
Iter 67: Train loss 1.472, Learning Rate 1.000e-05, It/sec 0.394, Tokens/sec 596.880, Trained Tokens 83006, Peak mem 12.744 GB
Iter 68: Train loss 1.413, Learning Rate 1.000e-05, It/sec 0.318, Tokens/sec 577.550, Trained Tokens 84824, Peak mem 12.744 GB
Iter 69: Train loss 1.605, Learning Rate 1.000e-05, It/sec 0.473, Tokens/sec 599.301, Trained Tokens 86090, Peak mem 12.744 GB
Iter 70: Train loss 1.450, Learning Rate 1.000e-05, It/sec 0.610, Tokens/sec 599.262, Trained Tokens 87073, Peak mem 12.744 GB
Iter 70: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000070_adapters.safetensors.
Iter 71: Train loss 1.467, Learning Rate 1.000e-05, It/sec 0.298, Tokens/sec 573.052, Trained Tokens 88996, Peak mem 12.744 GB
Iter 72: Train loss 1.228, Learning Rate 1.000e-05, It/sec 0.382, Tokens/sec 598.298, Trained Tokens 90562, Peak mem 12.744 GB
Iter 73: Train loss 2.151, Learning Rate 1.000e-05, It/sec 1.282, Tokens/sec 587.172, Trained Tokens 91020, Peak mem 12.744 GB
Iter 74: Train loss 1.185, Learning Rate 1.000e-05, It/sec 0.384, Tokens/sec 436.397, Trained Tokens 92157, Peak mem 12.744 GB
Iter 75: Train loss 1.664, Learning Rate 1.000e-05, It/sec 0.361, Tokens/sec 586.634, Trained Tokens 93783, Peak mem 12.744 GB
Iter 76: Train loss 1.446, Learning Rate 1.000e-05, It/sec 0.300, Tokens/sec 536.035, Trained Tokens 95568, Peak mem 12.744 GB
Iter 77: Train loss 1.337, Learning Rate 1.000e-05, It/sec 0.748, Tokens/sec 617.211, Trained Tokens 96393, Peak mem 12.744 GB
Iter 78: Train loss 1.480, Learning Rate 1.000e-05, It/sec 0.600, Tokens/sec 614.257, Trained Tokens 97416, Peak mem 12.744 GB
Iter 79: Train loss 1.411, Learning Rate 1.000e-05, It/sec 0.359, Tokens/sec 496.979, Trained Tokens 98801, Peak mem 12.744 GB
Iter 80: Train loss 1.302, Learning Rate 1.000e-05, It/sec 0.844, Tokens/sec 612.869, Trained Tokens 99527, Peak mem 12.744 GB
Iter 80: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000080_adapters.safetensors.
Iter 81: Train loss 0.899, Learning Rate 1.000e-05, It/sec 0.491, Tokens/sec 472.231, Trained Tokens 100488, Peak mem 12.744 GB
Iter 82: Train loss 1.505, Learning Rate 1.000e-05, It/sec 0.180, Tokens/sec 354.323, Trained Tokens 102455, Peak mem 12.744 GB
Iter 83: Train loss 1.514, Learning Rate 1.000e-05, It/sec 0.506, Tokens/sec 609.009, Trained Tokens 103659, Peak mem 12.744 GB
Iter 84: Train loss 1.533, Learning Rate 1.000e-05, It/sec 0.835, Tokens/sec 536.039, Trained Tokens 104301, Peak mem 12.744 GB
Iter 85: Train loss 1.409, Learning Rate 1.000e-05, It/sec 0.612, Tokens/sec 603.270, Trained Tokens 105287, Peak mem 12.744 GB
Iter 86: Train loss 1.699, Learning Rate 1.000e-05, It/sec 0.654, Tokens/sec 620.038, Trained Tokens 106235, Peak mem 12.744 GB
Iter 87: Train loss 1.415, Learning Rate 1.000e-05, It/sec 0.443, Tokens/sec 591.791, Trained Tokens 107570, Peak mem 12.744 GB
Iter 88: Train loss 0.965, Learning Rate 1.000e-05, It/sec 0.368, Tokens/sec 518.771, Trained Tokens 108980, Peak mem 12.744 GB
Iter 89: Train loss 1.227, Learning Rate 1.000e-05, It/sec 0.338, Tokens/sec 363.499, Trained Tokens 110057, Peak mem 12.744 GB
Iter 90: Train loss 1.449, Learning Rate 1.000e-05, It/sec 0.347, Tokens/sec 595.435, Trained Tokens 111774, Peak mem 12.744 GB
Iter 90: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000090_adapters.safetensors.
Iter 91: Train loss 1.358, Learning Rate 1.000e-05, It/sec 0.544, Tokens/sec 597.848, Trained Tokens 112873, Peak mem 12.744 GB
Iter 92: Train loss 1.535, Learning Rate 1.000e-05, It/sec 0.510, Tokens/sec 493.189, Trained Tokens 113840, Peak mem 12.744 GB
Iter 93: Train loss 1.172, Learning Rate 1.000e-05, It/sec 0.569, Tokens/sec 608.483, Trained Tokens 114909, Peak mem 12.744 GB
Iter 94: Train loss 1.540, Learning Rate 1.000e-05, It/sec 0.350, Tokens/sec 590.328, Trained Tokens 116595, Peak mem 12.744 GB
Iter 95: Train loss 1.309, Learning Rate 1.000e-05, It/sec 0.382, Tokens/sec 599.283, Trained Tokens 118162, Peak mem 12.744 GB
Iter 96: Train loss 1.005, Learning Rate 1.000e-05, It/sec 0.543, Tokens/sec 617.826, Trained Tokens 119299, Peak mem 12.744 GB
Iter 97: Train loss 1.548, Learning Rate 1.000e-05, It/sec 0.466, Tokens/sec 599.246, Trained Tokens 120585, Peak mem 12.744 GB
Iter 98: Train loss 1.339, Learning Rate 1.000e-05, It/sec 0.401, Tokens/sec 592.229, Trained Tokens 122061, Peak mem 12.744 GB
Iter 99: Train loss 1.658, Learning Rate 1.000e-05, It/sec 0.443, Tokens/sec 598.112, Trained Tokens 123412, Peak mem 12.744 GB
Iter 100: Val loss 1.505, Val took 1.323s
Iter 100: Train loss 1.380, Learning Rate 1.000e-05, It/sec 0.363, Tokens/sec 446.054, Trained Tokens 124641, Peak mem 12.744 GB

Goekdeniz-Guelmez avatar Feb 28 '25 22:02 Goekdeniz-Guelmez

That is definitely interesting but I think https://github.com/stockeh/mlx-optimizers may be a more suitable repository. Wdyt?

angeloskath avatar Mar 01 '25 01:03 angeloskath

@Goekdeniz-Guelmez 🔥 Perfect timing! The Muon optimizer just dropped, and now it’s already in MLX!!! pure optimization wizardry.

lin72h avatar Mar 01 '25 02:03 lin72h

@Goekdeniz-Guelmez @angeloskath yes, we have Muon already:

https://github.com/stockeh/mlx-optimizers/blob/main/mlx_optimizers/muon.py

though I do believe Keller Jordan has made some minor updates since.

stockeh avatar Mar 01 '25 06:03 stockeh

@stockeh I didn't know the optimizer repo existed :D. But yeah, there are some differences with the new version. It maintains the same mathematical principles but extends support to higher-dimensional tensors like conv filters through reshaping, rather than using a separate optimizer for them. It also improves efficiency with a streamlined Newton-Schulz iteration formula and applies weight decay earlier in the optimization process. The code now handles non-2D parameters more consistently and uses generalized transpose and normalization logic, so it works with tensors of any dimensionality (a rough sketch of the reshaping idea is below).
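Purely as an illustration of the reshaping idea (not code from this PR; it reuses the hypothetical newton_schulz_orthogonalize sketch from the first comment):

def orthogonalized_update(update, ns_steps=5):
    # 1D parameters (biases, norm scales) are not orthogonalized; they get special handling.
    if update.ndim < 2:
        return update
    # Tensors with more than two dimensions (e.g. conv filters of shape [out, in, kh, kw])
    # are flattened to [out, in*kh*kw], orthogonalized, then reshaped back.
    original_shape = update.shape
    if update.ndim > 2:
        update = update.reshape(update.shape[0], -1)
    update = newton_schulz_orthogonalize(update, steps=ns_steps)
    return update.reshape(original_shape)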

Goekdeniz-Guelmez avatar Mar 01 '25 10:03 Goekdeniz-Guelmez

hi, @stockeh

We recently worked on Muon and released the Moonlight model, see https://github.com/MoonshotAI/Moonlight/tree/master. We made some empirical observations about getting Muon to scale (which we did not see in the current implementation), and hope you do not mind me sharing them here:

  1. introducing weight decay, otherwise your weight RMS might grow too large when over-trained;
  2. adjusting the update RMS based on the matrix shape, otherwise your model weights will not have a consistent update RMS. This line (https://github.com/stockeh/mlx-optimizers/blob/main/mlx_optimizers/muon.py#L104) might be dangerous because it relies on a strong assumption that only holds in the nanoGPT setting.

The implementation is easy, see an example here: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L197-L203
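As a rough, hedged paraphrase of those two suggestions (illustrative only; the 0.2 * sqrt(max(A, B)) factor is the shape-based scaling described in the Moonlight report, and the exact formula should be checked against the linked example):

import math

def apply_muon_update(param, ortho_update, lr, weight_decay=0.1):
    # 1. Decoupled (AdamW-style) weight decay keeps the weight RMS bounded when over-training.
    param = param * (1 - lr * weight_decay)
    # 2. Scale the orthogonalized update by 0.2 * sqrt(max(A, B)) so its RMS stays
    #    consistent across matrix shapes and roughly matches AdamW's.
    A, B = ortho_update.shape
    return param - lr * (0.2 * math.sqrt(max(A, B))) * ortho_update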

These suggestions were empirically helpful for over-training, as observed during our pretraining of Moonlight. What are your opinions?

toothacher17 avatar Mar 01 '25 13:03 toothacher17

@toothacher17 Wow! The Moonlight team just popped in with actual scaling tips! 🚀 Love seeing them share those crucial details about weight decay and matrix shape adjustments. This is what makes open source so awesome - experts freely sharing knowledge that turns theory into production-ready code. MLX bringing ML minds together at its finest!

lin72h avatar Mar 01 '25 13:03 lin72h

@lin72h @toothacher17 I agree, I did not see that coming, but it's very welcome.

Goekdeniz-Guelmez avatar Mar 01 '25 13:03 Goekdeniz-Guelmez

@lin72h @Goekdeniz-Guelmez

Thanks! Our team, Moonshot AI, believed that Muon is scalable, and did some empirical experiments! In case you guys are interested, we have a tech report discussing it: https://arxiv.org/abs/2502.16982

toothacher17 avatar Mar 01 '25 13:03 toothacher17

@toothacher17 Just read the Moonshot paper - on par with the K1 paper and even more innovative than DeepSeek's work! You folks at Moonshot haven't gotten the attention you deserve - your work just got overshadowed by DeepSeek's timing. The ideas in Moonlight are truly incredible. Open-sourcing this level of innovation is something to be genuinely proud of. A real point of pride for the Chinese AI community!

lin72h avatar Mar 01 '25 13:03 lin72h

@lin72h Thanks a lot for the kind words! DeepSeek is truly doing a great job! I personally very much admire their contributions to pushing forward the progress of open source and AGI. We are a humble team and we'll keep working hard to deliver and publish good stuff!

toothacher17 avatar Mar 01 '25 14:03 toothacher17

@toothacher17 Keep up the awesome work! Moonshot rocks!

lin72h avatar Mar 01 '25 14:03 lin72h

@toothacher17 I appreciate you sharing your insights! I found the paper to be quite informative.

I think most of these changes are easy enough to implement in mlx-optimizers. Although I do wish there were an easier way to delegate the parameters we'd want to use with Muon and others with AdamW (for example), considering the difference in how torch and mlx optimizers are initialized. That said, I'm happy to add weight decay and scaling, as you have, with what's in there now!

stockeh avatar Mar 01 '25 14:03 stockeh

I do wish there were an easier way to delegate the parameters we'd want to use with Muon and others with AdamW

Can you say more about that?

awni avatar Mar 01 '25 19:03 awni

I do wish there were an easier way to delegate the parameters we'd want to use with Muon and others with AdamW

Can you say more about that?

I guess that's because, for now, AdamW is chained with Muon to handle the non-matrix parameters, e.g. the embedding, lm head, and rmsnorm gamma. In the future, there might be a chance to get rid of AdamW and use Muon purely, see for example https://github.com/modula-systems/modula. It's not proven at large scale yet, but it might be promising.

toothacher17 avatar Mar 02 '25 00:03 toothacher17

I do wish there were an easier way to delegate the parameters we'd want to use with Muon and others with AdamW

Can you say more about that?

@awni tldr: I don't think anything has to change with mlx, specifically, but I may change mlx-optimizers' Muon class to not include AdamW and simplify the delegation logic with a separate optim.

I originally said this when thinking about how we pass params to the optimizer, e.g., in KellerJordan/Muon

import torch
from muon import Muon  # reference implementation from KellerJordan/Muon

# Muon takes the >=2D "body" weights; AdamW takes the 1D params plus the head and embedding.
muon_params = [p for p in model.body.parameters() if p.ndim >= 2]
adamw_params = ([p for p in model.body.parameters() if p.ndim < 2]
              + [*model.head.parameters(), *model.embed.parameters()])
optimizers = [Muon(muon_params, lr=0.02, momentum=0.95),
              torch.optim.AdamW(adamw_params, lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)]
...
# in the training step
for opt in optimizers:
    opt.step()

Moonlight's implementation differs in that their custom Muon class accepts both muon_params and adamw_params to have only one optimizer. This kind of logic is a bit more challenging to generalize in mlx if we wanted to have more custom rules, e.g., set of layer names and dimensionality.

But, I thought about this some more and think it's easier as a general approach to just define multiple optimizers as we've discussed in this discussion, i.e.,

from mlx.utils import tree_flatten, tree_unflatten

def split_grads(grads):
    # Partition the gradient tree by dimensionality: 2D weights for Muon,
    # 1D biases/norm parameters for the second optimizer (e.g. AdamW).
    grads = tree_flatten(grads)
    weights = [(k, v) for k, v in grads if v.ndim == 2]
    biases = [(k, v) for k, v in grads if v.ndim == 1]
    weights = tree_unflatten(weights)
    biases = tree_unflatten(biases)
    return weights, biases

@partial(mx.compile, inputs=state, outputs=state)
def step(X, T):
    train_step_fn = nn.value_and_grad(self.model, self.eval_fn)
    loss, grads = train_step_fn(X, T)
    weights, biases = split_grads(grads)
    self.optimizers[0].update(self.model, weights)  # e.g. Muon for the 2D weights
    self.optimizers[1].update(self.model, biases)   # e.g. AdamW for everything else
    return loss
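To make that concrete, here is a hypothetical wiring of the two optimizers and the compiled state (in the class above these would be self.model / self.optimizers; hyperparameters mirror the torch snippet earlier and are illustrative only, using the Muon constructor shown later in this thread):

import mlx.core as mx
from mlx.optimizers import AdamW, Muon  # assumes the Muon class this PR adds

optimizers = [
    Muon(learning_rate=0.02, momentum=0.95),                          # 2D weights
    AdamW(learning_rate=3e-4, betas=[0.9, 0.95], weight_decay=0.01),  # biases, norms, etc.
]
# state captured by mx.compile: the model parameters plus both optimizers' states
state = [model.state, optimizers[0].state, optimizers[1].state]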

This would just require a bit of a refactor and some documentation for using Muon in mlx-optimizers, should the optimizers be separate.

stockeh avatar Mar 03 '25 01:03 stockeh

Thanks for the detailed explanation, that makes sense!

awni avatar Mar 03 '25 01:03 awni

Hey @awni I just wanted to quickly check if this PR is ready to be merged, or is there anything you’d like me to adjust?

Goekdeniz-Guelmez avatar Apr 21 '25 18:04 Goekdeniz-Guelmez

@Goekdeniz-Guelmez I'm not convinced it makes sense to merge this given that it's pretty niche/new and also already in mlx-optimizers, see here.

Anyone have any thoughts on that? If someone has a strong case for why we should include it that would be interesting to hear.. o/w I'd recommend we close this PR for now.

awni avatar Apr 22 '25 23:04 awni

Maybe it's time to revisit adding this to core? Wdyt? Seems like the optimizer is getting pretty popular.

awni avatar Jul 16 '25 02:07 awni

@awni I have similarly been seeing popularity grow and support adding it!

stockeh avatar Jul 16 '25 03:07 stockeh

I also think that, given the Kimi K2 release and my own experience, this optimizer is really good. Should I reopen this PR?

Goekdeniz-Guelmez avatar Jul 16 '25 06:07 Goekdeniz-Guelmez

Yes let's re-open it!

awni avatar Jul 16 '25 13:07 awni

It's mergeable imo:

python -m mlx_lm.lora \
--model mlx-community/Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1 \
--train \
--test \
--data mlx-community/wikisql \
--fine-tune-type lora \
--optimizer muon \
--batch-size 1 \
--val-batches -1 \
--test-batches -1 \
--iters 1000 \
--grad-checkpoint \
--learning-rate 0.00002 \
--steps-per-report 1 \
--steps-per-eval 500 \
--adapter-path /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/MLX/adapter \
--save-every 100 \
--max-seq-length 2065

Muon:

Calculating loss...: 100it [00:04, 22.41it/s]
Iter 1: Val loss 2.762, Val took 4.463s
Iter 1: Train loss 3.078, Learning Rate 2.000e-05, It/sec 6.934, Tokens/sec 540.863, Trained Tokens 78, Peak mem 1.178 GB
Iter 2: Train loss 3.222, Learning Rate 2.000e-05, It/sec 6.558, Tokens/sec 524.610, Trained Tokens 158, Peak mem 1.180 GB
Iter 3: Train loss 1.723, Learning Rate 2.000e-05, It/sec 6.587, Tokens/sec 678.507, Trained Tokens 261, Peak mem 1.242 GB
Iter 4: Train loss 3.057, Learning Rate 2.000e-05, It/sec 8.465, Tokens/sec 558.691, Trained Tokens 327, Peak mem 1.242 GB
Iter 5: Train loss 3.067, Learning Rate 2.000e-05, It/sec 8.363, Tokens/sec 635.624, Trained Tokens 403, Peak mem 1.242 GB
Iter 6: Train loss 2.954, Learning Rate 2.000e-05, It/sec 8.236, Tokens/sec 741.208, Trained Tokens 493, Peak mem 1.242 GB
Iter 7: Train loss 2.936, Learning Rate 2.000e-05, It/sec 8.296, Tokens/sec 638.772, Trained Tokens 570, Peak mem 1.242 GB
Iter 8: Train loss 2.974, Learning Rate 2.000e-05, It/sec 8.265, Tokens/sec 586.825, Trained Tokens 641, Peak mem 1.242 GB
Iter 9: Train loss 2.643, Learning Rate 2.000e-05, It/sec 8.225, Tokens/sec 748.460, Trained Tokens 732, Peak mem 1.242 GB
Iter 10: Train loss 2.865, Learning Rate 2.000e-05, It/sec 8.192, Tokens/sec 704.551, Trained Tokens 818, Peak mem 1.242 GB
Iter 11: Train loss 3.036, Learning Rate 2.000e-05, It/sec 8.181, Tokens/sec 703.578, Trained Tokens 904, Peak mem 1.242 GB
Iter 12: Train loss 3.161, Learning Rate 2.000e-05, It/sec 8.617, Tokens/sec 637.651, Trained Tokens 978, Peak mem 1.242 GB
Iter 13: Train loss 1.611, Learning Rate 2.000e-05, It/sec 5.130, Tokens/sec 872.085, Trained Tokens 1148, Peak mem 1.372 GB
Iter 14: Train loss 3.123, Learning Rate 2.000e-05, It/sec 9.094, Tokens/sec 536.572, Trained Tokens 1207, Peak mem 1.372 GB
Iter 15: Train loss 2.610, Learning Rate 2.000e-05, It/sec 8.369, Tokens/sec 811.774, Trained Tokens 1304, Peak mem 1.372 GB
Iter 16: Train loss 2.970, Learning Rate 2.000e-05, It/sec 8.399, Tokens/sec 730.710, Trained Tokens 1391, Peak mem 1.372 GB
Iter 17: Train loss 2.620, Learning Rate 2.000e-05, It/sec 8.348, Tokens/sec 793.087, Trained Tokens 1486, Peak mem 1.372 GB
Iter 18: Train loss 2.524, Learning Rate 2.000e-05, It/sec 8.214, Tokens/sec 607.861, Trained Tokens 1560, Peak mem 1.372 GB
Iter 19: Train loss 2.223, Learning Rate 2.000e-05, It/sec 7.994, Tokens/sec 799.390, Trained Tokens 1660, Peak mem 1.372 GB
Iter 20: Train loss 2.869, Learning Rate 2.000e-05, It/sec 7.956, Tokens/sec 787.597, Trained Tokens 1759, Peak mem 1.372 GB
...
Iter 990: Train loss 1.255, Learning Rate 2.000e-05, It/sec 8.303, Tokens/sec 672.557, Trained Tokens 89568, Peak mem 1.443 GB
Iter 991: Train loss 1.327, Learning Rate 2.000e-05, It/sec 7.965, Tokens/sec 931.876, Trained Tokens 89685, Peak mem 1.443 GB
Iter 992: Train loss 1.319, Learning Rate 2.000e-05, It/sec 8.249, Tokens/sec 816.607, Trained Tokens 89784, Peak mem 1.443 GB
Iter 993: Train loss 1.274, Learning Rate 2.000e-05, It/sec 8.373, Tokens/sec 577.714, Trained Tokens 89853, Peak mem 1.443 GB
Iter 994: Train loss 2.662, Learning Rate 2.000e-05, It/sec 5.988, Tokens/sec 844.356, Trained Tokens 89994, Peak mem 1.443 GB
Iter 995: Train loss 1.634, Learning Rate 2.000e-05, It/sec 5.894, Tokens/sec 783.891, Trained Tokens 90127, Peak mem 1.443 GB
Iter 996: Train loss 1.691, Learning Rate 2.000e-05, It/sec 8.037, Tokens/sec 747.480, Trained Tokens 90220, Peak mem 1.443 GB
Iter 997: Train loss 1.862, Learning Rate 2.000e-05, It/sec 8.145, Tokens/sec 749.352, Trained Tokens 90312, Peak mem 1.443 GB
Iter 998: Train loss 1.057, Learning Rate 2.000e-05, It/sec 7.849, Tokens/sec 777.074, Trained Tokens 90411, Peak mem 1.443 GB
Iter 999: Train loss 1.148, Learning Rate 2.000e-05, It/sec 8.165, Tokens/sec 587.876, Trained Tokens 90483, Peak mem 1.443 GB
Calculating loss...: 100it [00:04, 21.57it/s]
Iter 1000: Val loss 1.353, Val took 4.639s
Iter 1000: Train loss 1.373, Learning Rate 2.000e-05, It/sec 7.173, Tokens/sec 817.668, Trained Tokens 90597, Peak mem 1.443 GB
Iter 1000: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/adapter/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/adapter/0001000_adapters.safetensors.
Saved final weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/adapter/adapters.safetensors.
Testing
Calculating loss...: 100it [00:04, 21.76it/s]
Test loss 1.592, Test ppl 4.916.

Adam:

Iter 1: Val loss 2.762, Val took 4.629s
Iter 1: Train loss 3.078, Learning Rate 2.000e-05, It/sec 8.597, Tokens/sec 670.546, Trained Tokens 78, Peak mem 1.178 GB
Iter 2: Train loss 3.149, Learning Rate 2.000e-05, It/sec 8.613, Tokens/sec 689.036, Trained Tokens 158, Peak mem 1.183 GB
Iter 3: Train loss 1.588, Learning Rate 2.000e-05, It/sec 8.330, Tokens/sec 857.991, Trained Tokens 261, Peak mem 1.245 GB
Iter 4: Train loss 2.716, Learning Rate 2.000e-05, It/sec 9.896, Tokens/sec 653.164, Trained Tokens 327, Peak mem 1.245 GB
Iter 5: Train loss 2.664, Learning Rate 2.000e-05, It/sec 9.097, Tokens/sec 691.350, Trained Tokens 403, Peak mem 1.245 GB
Iter 6: Train loss 2.508, Learning Rate 2.000e-05, It/sec 8.931, Tokens/sec 803.776, Trained Tokens 493, Peak mem 1.245 GB
Iter 7: Train loss 2.234, Learning Rate 2.000e-05, It/sec 9.184, Tokens/sec 707.149, Trained Tokens 570, Peak mem 1.245 GB
Iter 8: Train loss 2.148, Learning Rate 2.000e-05, It/sec 9.208, Tokens/sec 653.745, Trained Tokens 641, Peak mem 1.245 GB
Iter 9: Train loss 1.914, Learning Rate 2.000e-05, It/sec 9.111, Tokens/sec 829.067, Trained Tokens 732, Peak mem 1.245 GB
Iter 10: Train loss 2.117, Learning Rate 2.000e-05, It/sec 8.761, Tokens/sec 753.413, Trained Tokens 818, Peak mem 1.245 GB
Iter 11: Train loss 2.084, Learning Rate 2.000e-05, It/sec 9.087, Tokens/sec 781.472, Trained Tokens 904, Peak mem 1.245 GB
Iter 12: Train loss 2.065, Learning Rate 2.000e-05, It/sec 9.199, Tokens/sec 680.697, Trained Tokens 978, Peak mem 1.245 GB
Iter 13: Train loss 1.203, Learning Rate 2.000e-05, It/sec 5.404, Tokens/sec 918.720, Trained Tokens 1148, Peak mem 1.378 GB
Iter 14: Train loss 1.895, Learning Rate 2.000e-05, It/sec 10.859, Tokens/sec 640.659, Trained Tokens 1207, Peak mem 1.378 GB
Iter 15: Train loss 1.650, Learning Rate 2.000e-05, It/sec 9.421, Tokens/sec 913.827, Trained Tokens 1304, Peak mem 1.378 GB
Iter 16: Train loss 2.048, Learning Rate 2.000e-05, It/sec 9.660, Tokens/sec 840.452, Trained Tokens 1391, Peak mem 1.378 GB
Iter 17: Train loss 1.554, Learning Rate 2.000e-05, It/sec 9.261, Tokens/sec 879.769, Trained Tokens 1486, Peak mem 1.378 GB
Iter 18: Train loss 1.406, Learning Rate 2.000e-05, It/sec 9.092, Tokens/sec 672.836, Trained Tokens 1560, Peak mem 1.378 GB
Iter 19: Train loss 1.083, Learning Rate 2.000e-05, It/sec 8.553, Tokens/sec 855.250, Trained Tokens 1660, Peak mem 1.378 GB
Iter 20: Train loss 1.914, Learning Rate 2.000e-05, It/sec 8.492, Tokens/sec 840.691, Trained Tokens 1759, Peak mem 1.378 GB
...
Iter 990: Train loss 0.750, Learning Rate 2.000e-05, It/sec 9.093, Tokens/sec 736.566, Trained Tokens 89568, Peak mem 1.450 GB
Iter 991: Train loss 0.858, Learning Rate 2.000e-05, It/sec 8.638, Tokens/sec 1010.684, Trained Tokens 89685, Peak mem 1.450 GB
Iter 992: Train loss 1.156, Learning Rate 2.000e-05, It/sec 8.467, Tokens/sec 838.268, Trained Tokens 89784, Peak mem 1.450 GB
Iter 993: Train loss 1.139, Learning Rate 2.000e-05, It/sec 9.020, Tokens/sec 622.408, Trained Tokens 89853, Peak mem 1.450 GB
Iter 994: Train loss 2.400, Learning Rate 2.000e-05, It/sec 6.460, Tokens/sec 910.811, Trained Tokens 89994, Peak mem 1.450 GB
Iter 995: Train loss 1.364, Learning Rate 2.000e-05, It/sec 6.472, Tokens/sec 860.731, Trained Tokens 90127, Peak mem 1.450 GB
Iter 996: Train loss 1.503, Learning Rate 2.000e-05, It/sec 9.551, Tokens/sec 888.237, Trained Tokens 90220, Peak mem 1.450 GB
Iter 997: Train loss 1.912, Learning Rate 2.000e-05, It/sec 9.095, Tokens/sec 836.703, Trained Tokens 90312, Peak mem 1.450 GB
Iter 998: Train loss 0.852, Learning Rate 2.000e-05, It/sec 8.562, Tokens/sec 847.606, Trained Tokens 90411, Peak mem 1.450 GB
Iter 999: Train loss 1.021, Learning Rate 2.000e-05, It/sec 8.954, Tokens/sec 644.716, Trained Tokens 90483, Peak mem 1.450 GB
Calculating loss...: 100it [00:04, 20.44it/s]
Iter 1000: Val loss 1.201, Val took 4.898s
Iter 1000: Train loss 1.148, Learning Rate 2.000e-05, It/sec 8.861, Tokens/sec 1010.194, Trained Tokens 90597, Peak mem 1.450 GB
Iter 1000: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/adapter/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/adapter/0001000_adapters.safetensors.
Saved final weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/adapter/adapters.safetensors.
Testing
Calculating loss...: 100it [00:04, 21.41it/s]
Test loss 1.526, Test ppl 4.598.

Goekdeniz-Guelmez avatar Jul 16 '25 17:07 Goekdeniz-Guelmez

@awni @angeloskath would you also mind trying it out too?

Goekdeniz-Guelmez avatar Jul 17 '25 11:07 Goekdeniz-Guelmez

@awni nice thanks, how about now?

Goekdeniz-Guelmez avatar Jul 17 '25 18:07 Goekdeniz-Guelmez

I did a bit more simplification / nits. I think it is in good shape. What would be really nice is if we could confirm the implementation is correct against some reference. Would you mind checking that? Ideally, check that a single update on some random gradient matches the reference. I think if that checks out we are good to go.

awni avatar Jul 17 '25 19:07 awni

So I compared it with the OG torch implementation:

Code

Torch

import torch
import torch.nn as nn
from muon import Muon
import numpy as np

# Set seed
torch.manual_seed(42)

shared_weight = np.random.randn(4, 4).astype(np.float32)

# Define identical models
class SimpleModel(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.layer = nn.Linear(4, 4, bias=False)
        self.layer.weight.data = torch.tensor(weight)

model_muon = SimpleModel(shared_weight)
model_ref = SimpleModel(shared_weight)

# Simulate a known gradient
fake_grad = torch.randn_like(model_muon.layer.weight)
model_muon.layer.weight.grad = fake_grad.clone()
model_ref.layer.weight.grad = fake_grad.clone()

# Define optimizers
muon_opt = Muon(
    [model_muon.layer.weight],
    lr=0.02,
    weight_decay=0.0,
    momentum=0.95
)
ref_opt = torch.optim.AdamW(model_ref.parameters(), lr=0.02, betas=(0.9, 0.95), weight_decay=0.0)

# Take one optimizer step
muon_opt.step()
ref_opt.step()

# Compare weights
print("Fake gradient:\n", fake_grad)
print("\nMuon updated weights:\n", model_muon.layer.weight.data)
print("\nReference (AdamW) updated weights:\n", model_ref.layer.weight.data)

# Check closeness
diff = torch.abs(model_muon.layer.weight.data - model_ref.layer.weight.data)
print("\nAbsolute difference:\n", diff)
print("\nMax difference:", diff.max().item())

Output:

Fake gradient:
 tensor([[-1.3847, -0.8712, -0.2234,  1.7174],
        [ 0.3189, -0.4245,  0.3057, -0.7746],
        [-1.5576,  0.9956, -0.8798, -0.6011],
        [-1.2742,  2.1228, -1.2347, -0.4879]])

Muon updated weights:
 tensor([[ 1.1037,  1.5311,  0.1155, -0.5821],
        [ 0.2022,  0.4548, -0.8730,  1.3344],
        [-0.3669,  0.8592,  1.0264,  1.1324],
        [ 0.5417, -1.1792,  1.7016,  0.1289]])

Reference (AdamW) updated weights:
 tensor([[ 1.1176,  1.5436,  0.1292, -0.5934],
        [ 0.1854,  0.4632, -0.9043,  1.3455],
        [-0.3629,  0.8464,  1.0510,  1.1452],
        [ 0.5584, -1.1873,  1.7039,  0.1477]])

Absolute difference:
 tensor([[0.0139, 0.0125, 0.0137, 0.0112],
        [0.0168, 0.0084, 0.0313, 0.0111],
        [0.0041, 0.0128, 0.0246, 0.0128],
        [0.0167, 0.0081, 0.0023, 0.0187]])

Max difference: 0.03132808208465576

MLX

import mlx.core as mx
import mlx.nn as nn
from mlx.optimizers import Muon, AdamW
import numpy as np

# Set seed
np.random.seed(42)

shared_weight = np.random.randn(4, 4).astype(np.float32)

# Define identical models
class SimpleModel(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.layer = nn.Linear(4, 4, bias=False)
        self.layer.weight = mx.array(weight)

model_muon = SimpleModel(shared_weight)
model_ref = SimpleModel(shared_weight)

model_ref.update(model_muon.parameters())

# Simulate a known gradient
fake_grad = mx.array(np.random.randn(4, 4).astype(np.float32))

# Define optimizers
muon_opt = Muon(
    learning_rate=0.02,
    momentum=0.95,
    weight_decay=0.0
)

adamw_opt = AdamW(
    learning_rate=0.02,
    betas=(0.9, 0.95),
    weight_decay=0.0
)
gradients = {"layer": {"weight": fake_grad}}

adamw_opt.update(model_ref, gradients)
muon_opt.update(model_muon, gradients)

# Print everything
print("Fake gradient:\n", fake_grad)
print("\nMuon updated weights:\n", model_muon.layer.weight)
print("\nReference (AdamW) updated weights:\n", model_ref.layer.weight)
print("\nAbsolute difference:\n", mx.abs(model_muon.layer.weight - model_ref.layer.weight))
print("\nMax difference:", mx.abs(model_muon.layer.weight - model_ref.layer.weight).max().item())

Fake gradient:
 array([[-1.01283, 0.314247, -0.908024, -1.4123],
       [1.46565, -0.225776, 0.0675282, -1.42475],
       [-0.544383, 0.110923, -1.15099, 0.375698],
       [-0.600639, -0.291694, -0.601707, 1.85228]], dtype=float32)

Muon updated weights:
 array([[0.504857, -0.135187, 0.652319, 1.53351],
       [-0.246369, -0.228068, 1.58206, 0.778105],
       [-0.47041, 0.534424, -0.451733, -0.470499],
       [0.247182, -1.90243, -1.71924, -0.574942]], dtype=float32)

Reference (AdamW) updated weights:
 array([[0.505658, -0.147209, 0.656633, 1.53197],
       [-0.243098, -0.225193, 1.57027, 0.776379],
       [-0.46053, 0.533616, -0.454473, -0.474674],
       [0.250907, -1.90434, -1.71597, -0.571232]], dtype=float32)

Absolute difference:
 array([[0.000801325, 0.012022, 0.00431389, 0.00153613],
       [0.00327107, 0.00287485, 0.011796, 0.00172579],
       [0.00987995, 0.000808716, 0.0027408, 0.00417501],
       [0.00372417, 0.00190568, 0.00326455, 0.00371057]], dtype=float32)

Max difference: 0.012021973729133606

Goekdeniz-Guelmez avatar Jul 18 '25 09:07 Goekdeniz-Guelmez

It's very close to the reference behavior of AdamW, even closer than the OG implementation, which is funny 😄

Goekdeniz-Guelmez avatar Jul 18 '25 09:07 Goekdeniz-Guelmez