
Documentation request: saving a model and loading after training

Open · sandeepimpressico opened this issue 1 year ago · 4 comments

Could you please add to the documentation how to save a model when using MLX, i.e. how to save the final model for inference once training is complete?

Perhaps even add sample code to one of the mlx-examples, e.g. transformer_lm. Showing how to save checkpoints would be useful too.

I see multiple ways to save (e.g. mx.savez(), model.save_weights()), and it's unclear which one is best, i.e. which saves all the required state, and what the corresponding methods are for loading it back from disk.

sandeepimpressico · Jan 13 '24 16:01

Maybe there isn't official documentation yet, but from my understanding, MLX originally used savez. It has since added safetensors support (save_safetensors), so there are some inconsistencies due to the framework's rapid development. However, the lora example should have the most up-to-date way to save and load model weights. I agree that official documentation on how to do this would be great.

mzbac · Jan 14 '24 03:01

Example: https://github.com/ml-explore/mlx-examples/tree/main/lora

Saving: look at this file to see how saving is used in the training loop: https://github.com/ml-explore/mlx-examples/blame/main/lora/lora.py#L327. Docs: https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.savez.html

Loading: loading weights from the .npz format is shown at https://github.com/ml-explore/mlx-examples/blame/main/lora/lora.py#L335. Docs: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.Module.load_weights.html?highlight=load+weights

bigsnarfdude · Jan 15 '24 01:01

@bigsnarfdude Would this work for checkpointing as well, or do I need to save additional data for checkpointing?

sandeepimpressico · Jan 17 '24 01:01

@sandeepimpressico It looks like the framework has new code for checkpointing. The .npz file is all that's needed: save it during training and load it back with load_weights. Here is the code:

https://github.com/ml-explore/mlx-examples/commit/d8680a89f986492dbc27c36af3294034db26458f

bigsnarfdude · Jan 17 '24 18:01