
Add a pip-installable, simple implementation of MeZO (along with a distributed impl. and some tests)

Open · lebrice opened this issue 1 year ago · 3 comments

Hello there!

I was very interested in your work after seeing it at NeurIPS, and I'd like to experiment with it in the future. To make that easier, I thought it would be useful to add a simple, standalone implementation of your algorithm to this repository, so that people can import it into their own code more easily.

Here are my contributions, if you're interested:

  • Add a simple, readable, standalone implementation of the MeZO update in a new mezo package
    • mezo.update: Perform a single MeZO update given the model, loss function, inputs, random seed, epsilon and learning rate.
      • NOTE: This also implements a minor improvement over the original algorithm: the update can be split into smaller chunks whenever a weight matrix is too large. This way, the maximum additional VRAM used during a MeZO update can be chosen a priori, instead of being dictated by the size of the largest weight (e.g. the embedding matrix in LLMs).
    • mezo.reconstruct_updates: Reconstructs a sequence of MeZO updates given the model, random seeds and projected gradients of each step
    • mezo.average_of_updates: Applies the average of multiple MeZO updates, given the model and the random seeds and projected gradients of each step or worker
    • mezo.distributed_update: Each worker sends its projected gradient to all other workers (the random seed is communicated implicitly), and every worker then reconstructs the average update across all workers.
  • Add a distributed MeZO update in mezo.distributed
  • Make the repository pip-installable and add brief installation instructions to the README.md
  • Add unit tests for each major added function (mezo.update, mezo.reconstruct_updates, mezo.average_of_updates, mezo.distributed_update)
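A minimal sketch of the idea that mezo.update, mezo.reconstruct_updates and mezo.average_of_updates rely on (illustrative NumPy, not the code in this PR): the perturbation z is never stored, only regenerated from the seed, so a whole step is described by a (seed, projected gradient) pair. NumPy arrays stand in for model weights, and every function name below (perturb, mezo_step, replay_updates, apply_average_update) is hypothetical.

```python
import numpy as np

# Hypothetical sketch of the seed-replay idea, not the PR's mezo package.
# NumPy arrays stand in for model weights; loss_fn maps a list of arrays to a float.


def perturb(params, seed, scale):
    """In place: p += scale * z for each array p, with z ~ N(0, I) regenerated from seed."""
    rng = np.random.default_rng(seed)
    for p in params:
        p += scale * rng.standard_normal(p.shape)


def mezo_step(params, loss_fn, seed, eps, lr):
    """One zeroth-order (SPSA-style) step; returns the scalar projected gradient."""
    perturb(params, seed, eps)           # theta + eps * z
    loss_plus = loss_fn(params)
    perturb(params, seed, -2.0 * eps)    # theta - eps * z
    loss_minus = loss_fn(params)
    perturb(params, seed, eps)           # back to theta, up to rounding
    projected_grad = (loss_plus - loss_minus) / (2.0 * eps)
    perturb(params, seed, -lr * projected_grad)  # theta <- theta - lr * g * z
    return projected_grad


def replay_updates(params, seeds, projected_grads, lr):
    """Re-apply a sequence of steps using only the (seed, projected_grad) pairs."""
    for seed, g in zip(seeds, projected_grads):
        perturb(params, seed, -lr * g)


def apply_average_update(params, seeds, projected_grads, lr):
    """Apply the mean of several updates that were all computed from the same params."""
    k = len(seeds)
    for seed, g in zip(seeds, projected_grads):
        perturb(params, seed, -lr * g / k)
```

Because replay_updates needs only the seeds and the scalar projected gradients, a training trajectory can be logged as two numbers per step and reproduced later from the initial weights (up to floating-point rounding).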
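The chunking mentioned in the NOTE above can be illustrated with a hypothetical helper (perturb_chunked and max_chunk are illustrative names, not the PR's API): z is drawn into a preallocated buffer of at most max_chunk elements, so the extra memory per perturbation is fixed ahead of time instead of being set by the largest weight. Reusing the same seed and max_chunk reproduces the same z, which is what the +eps / -2*eps / +eps sequence and the final update depend on.

```python
import numpy as np

# Hypothetical sketch of a chunked perturbation; names are illustrative.


def perturb_chunked(params, seed, scale, max_chunk):
    """In place: p += scale * z, drawing z at most max_chunk elements at a time.

    Extra memory is a single buffer of max_chunk float64 values, independent of
    the size of the largest parameter. Arrays must be C-contiguous so that
    reshape(-1) returns a view that writes back into p.
    """
    rng = np.random.default_rng(seed)
    buf = np.empty(max_chunk, dtype=np.float64)
    for p in params:
        if not p.flags.c_contiguous:
            raise ValueError("expected a C-contiguous array")
        flat = p.reshape(-1)  # a view, so writes go into p
        for start in range(0, flat.size, max_chunk):
            n = min(max_chunk, flat.size - start)
            z = buf[:n]
            rng.standard_normal(out=z)  # fill the buffer in place, no new allocation
            z *= scale
            flat[start:start + n] += z
```

The trade-off is more, smaller random-number draws per step; max_chunk only bounds the temporary buffer, not the memory of the weights themselves.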
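The communication pattern described for mezo.distributed_update can be simulated in a single process. In this sketch (hypothetical names, not the PR's code), each worker's seed is derived from the step number and its rank, which is an assumption made for illustration; with such a scheme, only one scalar projected gradient per worker has to be exchanged. In a real multi-process setup, the Python list of gradients below would instead be filled by a collective such as torch.distributed.all_gather.

```python
import numpy as np

# Hypothetical single-process simulation of a data-parallel MeZO step.


def perturb(params, seed, scale):
    """In place: p += scale * z, with z ~ N(0, I) regenerated from seed."""
    rng = np.random.default_rng(seed)
    for p in params:
        p += scale * rng.standard_normal(p.shape)


def projected_grad(params, loss_fn, seed, eps):
    """(L(theta + eps z) - L(theta - eps z)) / (2 eps); params restored afterwards."""
    perturb(params, seed, eps)
    loss_plus = loss_fn(params)
    perturb(params, seed, -2.0 * eps)
    loss_minus = loss_fn(params)
    perturb(params, seed, eps)
    return (loss_plus - loss_minus) / (2.0 * eps)


def worker_seed(step, rank, world_size):
    # Assumed seeding scheme: any worker can recompute any other worker's seed.
    return step * world_size + rank


def distributed_step(replicas, shards, loss_fn, step, eps, lr):
    """replicas[r] is worker r's copy of the weights; shards[r] is its data."""
    world_size = len(replicas)
    # 1. Each worker estimates a projected gradient on its own data shard.
    grads = [
        projected_grad(replicas[r], lambda ps, r=r: loss_fn(ps, shards[r]),
                       worker_seed(step, r, world_size), eps)
        for r in range(world_size)
    ]
    # 2. "All-gather": in this simulation every worker simply reads the list.
    # 3. Every worker applies the average of all workers' updates to its replica.
    for params in replicas:
        for r, g in enumerate(grads):
            perturb(params, worker_seed(step, r, world_size), -lr * g / world_size)
    return grads
```

Each worker sends a single float per step regardless of model size, and all replicas apply the same averaged update, so they stay in sync up to floating-point rounding.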

If you'd like to merge these changes into your repo, could you please double-check that I didn't miss anything in my re-implementation of the algorithm (for example, by reading through the mezo.update and mezo.distributed_update functions)?

Thanks and congratulations on this great work!

lebrice, Dec 20 '23

Hey @gaotianyu1350 @sadhikamalladi @eltociear @danqi, would you be interested in reviewing this contribution to your repo?

lebrice, Mar 01 '24

@lebrice I appreciate the effort you've put into this; it's already contributing productively to some experiments we're running at the moment!

Alignment-Lab-AI, Mar 27 '24

@gaotianyu1350 @sadhikamalladi @eltociear @danqi

lebrice, Jun 03 '24