mlx-examples
Example of a Convolutional Variational Autoencoder (CVAE) on MNIST
This PR introduces a (small) Convolutional Variational Autoencoder (CVAE) example on MNIST. Not sure whether this is interesting for mlx-examples or not. It is a port from https://github.com/menzHSE/torch-vae to mlx (limited to MNIST for now).
The example includes model training, reconstruction of training / test images, and generation of novel image samples. A small pre-trained model is included that allows reconstructing / generating images without training.
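For illustration (this is not the actual model.py from the PR), here is a minimal sketch of the core CVAE pieces in MLX: a convolutional encoder that produces mu / logvar, the reparameterization trick, and a decoder. The layer sizes and the purely linear decoder stub are assumptions made for brevity; the real example differs.

```python
# Minimal CVAE sketch in MLX -- illustrative only, not the model.py from this PR.
# Assumptions: NHWC image layout (MLX convention), 28x28x1 MNIST inputs, and a
# simple fully-connected decoder stub in place of the real convolutional decoder.
import mlx.core as mx
import mlx.nn as nn


class CVAE(nn.Module):
    def __init__(self, num_filters=64, latent_dims=8):
        super().__init__()
        # Encoder: 28x28x1 -> 14x14xF -> 7x7xF via strided convolutions
        self.conv1 = nn.Conv2d(1, num_filters, 3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm(num_filters)
        self.conv2 = nn.Conv2d(num_filters, num_filters, 3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm(num_filters)
        self.fc_mu = nn.Linear(7 * 7 * num_filters, latent_dims)
        self.fc_logvar = nn.Linear(7 * 7 * num_filters, latent_dims)
        # Decoder stub: latent -> flattened 28x28 image (the real example differs)
        self.fc_dec = nn.Linear(latent_dims, 28 * 28)

    def encode(self, x):
        h = nn.relu(self.bn1(self.conv1(x)))
        h = nn.relu(self.bn2(self.conv2(h)))
        h = h.reshape(h.shape[0], -1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I)
        eps = mx.random.normal(mu.shape)
        return mu + mx.exp(0.5 * logvar) * eps

    def decode(self, z):
        return mx.sigmoid(self.fc_dec(z)).reshape(-1, 28, 28, 1)

    def __call__(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
```

Generating novel samples then amounts to decoding z drawn from N(0, I), which is what the generate script does.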
It is tested with mlx@026ef9a but should probably wait until https://github.com/ml-explore/mlx/pull/385 (remove the retain_graph flag) appears in an official release. I am assuming this will be in mlx 0.0.8.
Checklist
Put an x in the boxes that apply.
- [x] I have read the CONTRIBUTING document
- [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
- [NA - example added] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the necessary documentation (if needed)
TODO
- [ ] Test with the next official mlx release and verify that mlx>=0.0.8 in requirements.txt is still valid
- [x] Look into model loading with strict=True. Edit: Fixed in https://github.com/ml-explore/mlx/pull/409
There is one existing issue that I am not sure how to handle. When loading the model with strict=True, I am getting a shape error for the BatchNorm layers. Loading with strict=False seems to load everything correctly, as far as I could check.
With load_weights(fname, strict=True):
$ python generate.py --model=pretrained/vae_mnist_filters_0064_dims_0008.npz --latent_dims=8 --outfile=samples.png --seed=0
Traceback (most recent call last):
File "/Users/menzHSE/Development/mlx-examples/cvae/generate.py", line 88, in <module>
generate(
File "/Users/menzHSE/Development/mlx-examples/cvae/generate.py", line 19, in generate
vae.load(model_fname)
File "/Users/menzHSE/Development/mlx-examples/cvae/model.py", line 235, in load
self.load_weights(fname, strict=True)
File "/Users/menzHSE/Development/miniforge3/envs/mlx-pytorch-2024-01/lib/python3.11/site-packages/mlx/nn/layers/base.py", line 158, in load_weights
raise ValueError(
ValueError: Expected shape [16] but received shape [1, 1, 1, 16] for parameter encoder.bn1.running_mean
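Before the fix in https://github.com/ml-explore/mlx/pull/409 landed, one possible workaround (matching the strict=False observation above) would have been to fall back to non-strict loading. This is only a sketch; the method name mirrors the traceback above, not the PR's actual code.

```python
# Possible workaround sketch prior to ml-explore/mlx#409 (illustrative only;
# the load method mirrors the traceback above, not the PR's actual code).
def load(self, fname):
    try:
        self.load_weights(fname, strict=True)
    except ValueError:
        # Older mlx releases stored BatchNorm running stats with shape
        # [1, 1, 1, C] instead of [C]; strict=False skips the strict
        # shape/key validation and appeared to load correctly here.
        self.load_weights(fname, strict=False)
```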
@menzHSE this is awesome! I think a VAE example is definitely in scope for this repo. Just as a warning we have a bit of a review backlog so it might be a little while before we can get to reviewing it and merging. But I would love to include this!
Also the BatchNorm thing looks like a bug... with saving and loading BN stats... I will take a look at that.
Awesome! I'll keep working on it and will wait for the next mlx release anyway.
Regarding the batch norm issue: I tried but could not reproduce it with a minimal example. Maybe I am missing something in my model def.
The Batch Norm issue should be fixed in https://github.com/ml-explore/mlx/pull/409
Verified that https://github.com/ml-explore/mlx/pull/409 fixed the loading in strict mode with mlx@1d90a76
- Tested with mlx 0.0.9
- Requirements and documentation updated
- Style checked
- Rebased and forced-pushed
I think it is ready for review.
Nice work. I pulled your CVAE branch for the PR, @menzHSE, and tested it with the latest mlx-core and mlx-datasets. Tested on an MBP M2 Pro, 32 GB.
...
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 32.0KiB (53.1MiB/s)
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz 8.0KiB (13.2MiB/s)
Number of trainable params: 0.1493 M
Epoch 0 | Loss 19312.91 | Throughput 947.85 im/s | Time 64.4 (s)
Epoch 1 | Loss 11821.64 | Throughput 942.04 im/s | Time 64.7 (s)
Epoch 2 | Loss 10845.00 | Throughput 941.41 im/s | Time 64.8 (s)
Epoch 3 | Loss 10316.70 | Throughput 941.30 im/s | Time 64.8 (s)
Epoch 4 | Loss 9976.91 | Throughput 941.35 im/s | Time 64.8 (s)
Epoch 5 | Loss 9739.50 | Throughput 941.35 im/s | Time 64.8 (s)
Epoch 6 | Loss 9564.17 | Throughput 939.39 im/s | Time 64.9 (s)
Epoch 7 | Loss 9425.80 | Throughput 935.16 im/s | Time 65.2 (s)
Epoch 8 | Loss 9312.95 | Throughput 939.39 im/s | Time 64.9 (s)
python3 generate.py --model=pretrained/vae_mnist_filters_0064_dims_0008.npz --latent_dims=8 --outfile=samples.png --seed=0
Loaded model with 8 latent dims from pretrained/vae_mnist_filters_0064_dims_0008.npz
Saved 128 generated images to samples.png
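The per-epoch Loss column in the training log above is presumably the usual VAE objective (reconstruction error plus KL divergence), reported per image. Below is a minimal sketch of how such a loss is commonly computed in MLX, assuming the model returns x_hat, mu, and logvar as in the earlier sketch; whether the example uses a squared-error or cross-entropy reconstruction term is not confirmed here.

```python
import mlx.core as mx


# Sketch of a standard VAE loss in MLX (assumed form, not copied from the PR):
# reconstruction error plus the closed-form KL divergence to a unit Gaussian.
def vae_loss(x_hat, x, mu, logvar):
    recon = mx.sum(mx.square(x_hat - x))
    kl = -0.5 * mx.sum(1.0 + logvar - mx.square(mu) - mx.exp(logvar))
    # Average over the batch to report a per-image loss
    return (recon + kl) / x.shape[0]
```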
Tested with mlx-0.0.11, rebased and force-pushed
I left a couple of high-level questions, could you check?
I will do a more detailed review after hearing from you. FYI, with the examples I usually push nit fixes directly to the branch when reviewing, and if bigger changes are needed I will leave comments.
@menzHSE is this ready for review?
@awni Yes. I have made the minor changes requested above and it's ready for review.
@menzHSE this is really nice!
I took a deeper look today and simplified / reorganized the example to be more self-contained and in the style of our other small examples.
I want to update the README after finishing a full training run, but other than that I think we can merge this shortly. Let me know if you have any comments!
Thanks @awni. I have also started a training run.
@menzHSE you might be interested to check the updated README log. I notice the throughput I get on my M1 Max is much, much better than the numbers you had before. I'm wondering if your Pro is that much slower or if MLX has also gotten faster for you?
At any rate, I think this is good to merge. Thanks so much for the contribution, it's really nicely done!!
@awni Great! Thanks for merging. I am getting a throughput of approx. 485 im/s on my M1, which is quite a bit faster than the 410 im/s with previous mlx versions. Still a huge gap to your M1 Max ... time for an upgrade :)