mlx-examples
Documentation request: saving a model and loading after training
Could you please document how to save a model in MLX once training is complete, so that the final model can be loaded for inference?
Perhaps also add sample code to one of the mlx-examples, e.g. transformer_lm. Showing how to save checkpoints would be useful too.
I see multiple ways to save (e.g. mx.savez(), model.save_weights()), and it's unclear which one saves all the required state, and which method loads it back from disk.
Maybe there isn't official documentation yet, but from my understanding, MLX originally used savez. It has since added support for safetensors (mx.save_safetensors), so there are some inconsistencies due to the rapid development of the framework. However, the lora example should show the most up-to-date way to save and load model weights. I agree that official documentation on this would be great.
EXAMPLE: https://github.com/ml-explore/mlx-examples/tree/main/lora
SAVING: in this file, https://github.com/ml-explore/mlx-examples/blame/main/lora/lora.py#L327, you can see savez being used in the training loop. Docs: https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.savez.html
LOADING: loading weights from the .npz file is shown at https://github.com/ml-explore/mlx-examples/blame/main/lora/lora.py#L335. Docs: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.Module.load_weights.html?highlight=load+weights
@bigsnarfdude would this work for checkpointing as well, or do I need to save additional data for checkpointing?
@sandeepimpressico it looks like the framework has new code for checkpointing. The npz file is all that is needed: checkpoints are saved to it and restored with load_weights. Here is the code:
https://github.com/ml-explore/mlx-examples/commit/d8680a89f986492dbc27c36af3294034db26458f