
Integrate Muon optimizer (2725)

Open Saurabh750 opened this issue 6 months ago • 7 comments

Context

What is the purpose of this PR? Is it to

  • [x] add a new feature
  • [ ] fix a bug
  • [ ] update tests and/or documentation
  • [ ] other (please add here)

Please link to any issues this PR addresses. #2725

Changelog

What are the changes made in this PR?

  • Integrate the Muon optimizer as a PyTorch implementation in torchtune (a config sketch follows below).
  • Modify the recipes accordingly.
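
As a rough illustration, a recipe config could then select the optimizer along these lines (the component path and hyperparameters mirror values discussed later in this thread; exact names may still change):

optimizer:
  _component_: torchtune.modules.SingleDeviceMuon
  lr: 0.02
  momentum: 0.95
  weight_decay: 0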

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • [ ] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • [ ] add unit tests for any new functionality
  • [ ] update docstrings for any new or updated methods or classes
  • [ ] run unit tests via pytest tests
  • [ ] run recipe tests via pytest tests -m integration_test
  • [ ] manually run any new or modified recipes with sufficient proof of correctness
  • [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example

  • [ ] I did not change any public API
  • [ ] I have added an example to docs or docstrings

Saurabh750 avatar Jun 08 '25 16:06 Saurabh750

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2803

Note: Links to docs will display an error until the docs builds have been completed.

:x: 1 New Failure

As of commit 617d5e898ea74f8f8605a02c541d9e62ba7a5f97 with merge base b22a3ae68c9e254e7ec6afb6ae64b1b32e0f0cbe:

NEW FAILURE - The following job has failed:

  • Lint / lint (3.10) (gh) torchtune/modules/optim.py:11:1: F401 'torch.distributed as dist' imported but unused

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Jun 08 '25 16:06 pytorch-bot[bot]

Thanks for the first pass! Let's split the single-device and distributed versions of Muon into two separate files to improve readability. Regarding plots: we will need performance comparisons against AdamW, the usual Wandb loss curves, and results from the evaluation recipe.

krammnic avatar Jun 08 '25 18:06 krammnic

I've added a few comments, but it looks great! There are two things we might need to think through, though:

  1. Can we reduce the number of "muon checks"? Maybe some special wrapper, similar to a fused optimizer?
  2. Maybe we should implement it a little more cleanly and support it through builders? A rough sketch of one possible shape follows.
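
For example, a builder-style helper along these lines (the function name, the Muon constructor signature, and the name-based filtering are all assumptions, not the final torchtune API):

# Rough sketch only: partitions parameters once at build time so the recipe
# loop never needs per-step muon checks.
from torch import nn
from torch.optim import AdamW


def build_muon_with_fallback(model: nn.Module, muon_cls, muon_lr: float = 0.02,
                             fallback_lr: float = 1e-5):
    muon_params, other_params = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Hidden 2D weight matrices go to Muon; embeddings, the output head,
        # norms, and biases typically stay on a standard optimizer.
        if p.ndim == 2 and "embed" not in name and "output" not in name:
            muon_params.append(p)
        else:
            other_params.append(p)
    return [muon_cls(muon_params, lr=muon_lr), AdamW(other_params, lr=fallback_lr)]

The recipe would then just iterate over the returned optimizers, with no optimizer-specific branching in the training step.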

krammnic avatar Jun 08 '25 20:06 krammnic

Based on the comments above, two things came to mind:

  1. As @krammnic suggested, I can implement a fused optimizer: a wrapper around Muon where the second optimizer, used for the linear layers, is chosen by the user. This would eliminate the muon checks and would be a cleaner approach; a rough sketch of such a wrapper is at the end of this comment.
  2. If we want more flexibility when assigning optimizers, we can do it at the parameter level, e.g. inside the config file:
optimizer:
  muon: [param1, param2]
  AdamW: []
muon:
  _component_: torchtune.modules.SingleDeviceMuon
  momentum: 0.95
  lr: 0.02
  weight_decay: 0
AdamW:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 1e-5

Muon would be assigned only to param1 and param2, while AdamW would cover all the remaining parameters. Please let me know your views; I'll implement whichever approach is preferred!
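
A minimal sketch of what the fused wrapper mentioned in point 1 could look like (the class name and interface are placeholders, not the final API):

import torch


class MuonWithAuxiliary:
    """Duck-typed wrapper driving Muon plus a user-chosen auxiliary optimizer
    behind a single step()/zero_grad() interface, so recipes need no
    muon-specific checks. Purely illustrative; not the final API."""

    def __init__(self, muon_opt, aux_opt):
        self.muon_opt = muon_opt
        self.aux_opt = aux_opt
        # Combined view so lr schedulers and logging see every param group.
        self.param_groups = muon_opt.param_groups + aux_opt.param_groups

    def step(self):
        self.muon_opt.step()
        self.aux_opt.step()

    def zero_grad(self, set_to_none: bool = True):
        self.muon_opt.zero_grad(set_to_none=set_to_none)
        self.aux_opt.zero_grad(set_to_none=set_to_none)

Under the parameter-level config above, param1 and param2 would go into muon_opt and everything else into aux_opt; the recipe would then interact only with the wrapper.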

Saurabh750 avatar Jun 10 '25 03:06 Saurabh750

@joecummings @krammnic: I have added Muon to optim.py. I tried fine-tuning on the Alpaca dataset using Qwen2-0.5B and compared AdamW and Muon, training for 5, 10, and 20 epochs with batch sizes of 5 and 10. In all of these experiments AdamW performed better than Muon. I am attaching a snippet from one of the experiments, with batch size 10 and 20 epochs: [adamwvsMuon loss comparison plot]

  • Should I try running the same experiments on a different model? According to this, the results we are getting are along the same lines: AdamW performs better than Muon when the model was not pretrained with Muon.

There is a custom implementation of Adam inside the Muon class. I tried using the existing PyTorch implementation instead, with the goal of allowing any pre-existing optimizer to be used for the linear layers. However, this is not possible because load_state_dict() supports storing only a single optimizer. I believe load_state_dict() in OptimizerInBackward does support multiple optimizers; please correct me if I am wrong.
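
For illustration, a minimal sketch of how a wrapper holding two inner optimizers could checkpoint both states (the class name and key layout are assumptions, loosely mirroring the spirit of OptimizerInBackward rather than its actual code):

class MultiOptimizerCheckpoint:
    """Illustrative only: shows how a wrapper around two inner optimizers
    could save and restore both states. Not the current Muon implementation."""

    def __init__(self, muon_opt, aux_opt):
        self.optimizers = {"muon": muon_opt, "aux": aux_opt}

    def state_dict(self):
        # Nest each inner optimizer's state under its own key.
        return {name: opt.state_dict() for name, opt in self.optimizers.items()}

    def load_state_dict(self, state_dict):
        for name, opt in self.optimizers.items():
            opt.load_state_dict(state_dict[name])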

I have tried reducing the muon checks, but there is still one muon check in the main recipe file. I have also updated the get_lr() method to return only Muon's learning rate; please let me know whether this is the right approach.
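
For reference, a minimal sketch of what that could look like, assuming Muon's param groups are tagged at construction time (the "use_muon" tag and helper name are hypothetical, not the actual torchtune API):

def get_muon_lr(optimizer) -> float:
    # Illustrative helper: return the lr of the first param group tagged as
    # belonging to Muon. The tagging scheme is an assumption.
    for group in optimizer.param_groups:
        if group.get("use_muon", False):
            return group["lr"]
    # Fall back to the first group's lr if no Muon group is tagged.
    return optimizer.param_groups[0]["lr"]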

Saurabh750 avatar Jun 17 '25 17:06 Saurabh750

Then something is wrong. I don't like that we are seeing worse performance, because the loss gap might be fixable with some HPO. I will review your changes tomorrow and try to find the bottleneck...

krammnic avatar Jun 17 '25 21:06 krammnic

Hi @joecummings @krammnic, I have updated the Muon optimizer implementation. In the first image below: Green -> Muon optimizer for the first 20 epochs with lr 5e-4; Blue -> Muon optimizer for epochs 20 to 40 with lr 5e-5.

[image: Muon loss curves]

In the second image: Blue -> Muon optimizer for epochs 20 to 40 with lr 5e-5; Red -> AdamW optimizer for epochs 0 to 20 with lr 2e-5.

[image: Muon vs. AdamW loss curves]

As suggested earlier, switching to another implementation and some hyperparameter tuning helped. In the lr_scheduler, I am only returning the Muon learning rate. Please let me know if anything else needs attention.

Saurabh750 avatar Jun 23 '25 18:06 Saurabh750