MeZO
Add a pip-installable, simple implementation of MeZO (along with a distributed impl. and some tests)
Hello there!
I was very interested in your work after seeing it at NeurIPS, and I'd like to play around with it a bit in the future. To that end, I thought it would be useful to add a simple, standalone implementation of your algorithm to your codebase, so that people can more easily import it into their own projects and use it.
Here are my contributions, if you're interested:
- Add a simple, readable, standalone implementation of the MeZO update in a new `mezo` package
  - `mezo.update`: Perform a single MeZO update given the model, loss function, inputs, random seed, epsilon and learning rate.
    - NOTE: This also implements a minor improvement w.r.t. the original algorithm: the update can be split into smaller chunks whenever a weight matrix is too large. This way, the maximum additional VRAM used during a MeZO update can be chosen a priori, instead of being the size of the largest weight (e.g. the embedding matrix in LLMs).
  - `mezo.reconstruct_updates`: Reconstructs a sequence of MeZO updates given the model, random seeds and projected gradients of each step
  - `mezo.average_of_updates`: Performs the average of multiple MeZO updates given the model, random seeds and projected gradients of each step or worker
  - `mezo.distributed_update`: Distributed update, where each worker communicates its projected grads (and, implicitly, its random seed) to all other workers. Each worker ends up reconstructing the average update from all workers.
- Add a distributed MeZO update in `mezo.distributed`
- Make the package pip-installable and add short install instructions to the README.md
- Add unit tests for every major function added (`mezo.update`, `mezo.reconstruct_updates`, `mezo.average_of_updates`, `mezo.distributed_update`)
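To illustrate the chunked update described in the NOTE above, here is a minimal, dependency-free sketch. It is not the code in this PR: it uses plain Python lists in place of tensors, and the names `perturb_in_place`, `mezo_step` and `chunk_size` are hypothetical. The idea it shows is that the perturbation `z` is regenerated from the seed each time it is needed, one chunk at a time, so only `chunk_size` noise values exist in memory at once.

```python
import random
from typing import Callable, List


def perturb_in_place(params: List[float], seed: int, scale: float, chunk_size: int) -> None:
    """Add scale * z to params in place, where z ~ N(0, 1) is regenerated from `seed`.

    The noise is materialised one chunk at a time, so at most `chunk_size`
    noise values are held in memory, whatever the size of `params`.
    """
    rng = random.Random(seed)  # same seed -> same z, independent of chunk_size
    for start in range(0, len(params), chunk_size):
        stop = min(start + chunk_size, len(params))
        z_chunk = [rng.gauss(0.0, 1.0) for _ in range(stop - start)]
        for i, z in zip(range(start, stop), z_chunk):
            params[i] += scale * z


def mezo_step(
    params: List[float],
    loss_fn: Callable[[List[float]], float],
    seed: int,
    epsilon: float = 1e-3,
    lr: float = 1e-2,
    chunk_size: int = 4,
) -> float:
    """One zeroth-order step; returns the projected gradient (a scalar)."""
    perturb_in_place(params, seed, +epsilon, chunk_size)      # theta + eps * z
    loss_plus = loss_fn(params)
    perturb_in_place(params, seed, -2 * epsilon, chunk_size)  # theta - eps * z
    loss_minus = loss_fn(params)
    perturb_in_place(params, seed, +epsilon, chunk_size)      # back to theta
    projected_grad = (loss_plus - loss_minus) / (2 * epsilon)
    # Descent step along the same direction z, again regenerated from the seed.
    perturb_in_place(params, seed, -lr * projected_grad, chunk_size)
    return projected_grad


def quadratic_loss(params: List[float]) -> float:
    """Toy loss used only for this illustration."""
    return sum(p * p for p in params)
```

Because the random generator is consumed in the same order whatever the chunk size, calling `mezo_step` with `chunk_size=1` or `chunk_size=100` produces the same parameters; only the peak amount of noise held in memory changes.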
If you'd like to add these changes to your repo, could you please make sure that I didn't miss anything in my re-implementation of the algorithm (perhaps by reading through the `mezo.update` and `mezo.distributed_update` functions, if possible)?
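As a rough aid for that check, the property that the reconstruction and averaging functions rely on can be sketched as follows. Again, this is a toy illustration with plain Python lists, not the code in this PR, and the names `add_scaled_noise`, `replay_updates` and `apply_average_update` are hypothetical. Since each update is fully determined by a `(seed, projected_grad)` pair, exchanging those two scalars is enough for another process to replay or average updates.

```python
import random
from typing import List, Sequence, Tuple

SeedAndGrad = Tuple[int, float]  # (random seed, projected gradient) for one step


def add_scaled_noise(params: List[float], seed: int, scale: float) -> None:
    """Add scale * z to params in place, with z ~ N(0, 1) regenerated from `seed`."""
    rng = random.Random(seed)
    for i in range(len(params)):
        params[i] += scale * rng.gauss(0.0, 1.0)


def replay_updates(params: List[float], steps: Sequence[SeedAndGrad], lr: float) -> None:
    """Re-apply a sequence of updates: theta -= lr * g_t * z(seed_t) for each step."""
    for seed, projected_grad in steps:
        add_scaled_noise(params, seed, -lr * projected_grad)


def apply_average_update(params: List[float], worker_steps: Sequence[SeedAndGrad], lr: float) -> None:
    """Apply the mean of the workers' updates: theta -= (lr / n) * sum_k g_k * z(seed_k)."""
    n = len(worker_steps)
    for seed, projected_grad in worker_steps:
        add_scaled_noise(params, seed, -lr * projected_grad / n)
```

In a distributed setting, the assumption is that every worker gathers all `(seed, projected_grad)` pairs and calls something like `apply_average_update` locally, so all replicas stay in sync without ever exchanging full-size gradients.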
Thanks and congratulations on this great work!
Hey @gaotianyu1350 @sadhikamalladi @eltociear @danqi, would you be interested in reviewing this contribution to your repo?
@lebrice I appreciate the effort you've gone through; this is adding productively to some experiments we're on at the moment!
@gaotianyu1350 @sadhikamalladi @eltociear @danqi