mlx-examples

Example of a Convolutional Variational Autoencoder (CVAE) on MNIST

Open • menzHSE opened this issue 1 year ago • 5 comments

This PR introduces a (small) Convolutional Variational Autoencoder (CVAE) example on MNIST. I am not sure whether this is of interest for mlx-examples. It is a port of https://github.com/menzHSE/torch-vae to MLX (limited to MNIST for now).

The example includes model training, reconstruction of training / test images, and generation of novel image samples. A small pre-trained model is included, so images can be reconstructed / generated without training.
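The training objective behind such an example can be sketched as below. This is a minimal NumPy sketch, not code from this PR or from MLX: it assumes a summed squared-error reconstruction term plus the closed-form KL divergence to a standard normal prior, and the function names are hypothetical.

```python
import numpy as np


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with sigma = exp(0.5 * logvar), eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims."""
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar))


def vae_loss(x, x_recon, mu, logvar):
    """Assumed loss form: summed squared reconstruction error plus the KL term.

    Binary cross-entropy is another common choice for the reconstruction term.
    """
    recon = np.sum((x_recon - x) ** 2)
    return recon + kl_to_standard_normal(mu, logvar)


# Usage: a batch of 4 "images" with 8 latent dims and a perfect reconstruction
rng = np.random.default_rng(0)
x = rng.random((4, 28, 28, 1))
mu = np.zeros((4, 8))
logvar = np.zeros((4, 8))
z = reparameterize(mu, logvar, rng)
print(z.shape)                     # (4, 8)
print(vae_loss(x, x, mu, logvar))  # 0.0
```

The reparameterization keeps sampling differentiable with respect to mu and logvar, which is what lets the encoder be trained by backpropagation.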

It is tested with mlx@026ef9a but should probably wait until https://github.com/ml-explore/mlx/pull/385 (remove retain_graph flag) appears in an official release. I am assuming this will be in mlx 0.0.8.

Checklist

Put an x in the boxes that apply.

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [NA - example added] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

TODO

  • [ ] Test with next official mlx release and verify mlx>=0.0.8 in requirements.txt is still valid
  • [x] Look into model loading with strict=True

menzHSE • Jan 09 '24 10:01

Edit: Fixed in https://github.com/ml-explore/mlx/pull/409

There is one remaining issue that I am not sure how to handle. When loading the model with strict=True, I get a shape error for the BatchNorm layers. Loading with strict=False seems to load everything correctly, as far as I could check.

With load_weights(fname, strict=True):

$ python generate.py --model=pretrained/vae_mnist_filters_0064_dims_0008.npz --latent_dims=8 --outfile=samples.png --seed=0
Traceback (most recent call last):
  File "/Users/menzHSE/Development/mlx-examples/cvae/generate.py", line 88, in <module>
    generate(
  File "/Users/menzHSE/Development/mlx-examples/cvae/generate.py", line 19, in generate
    vae.load(model_fname)
  File "/Users/menzHSE/Development/mlx-examples/cvae/model.py", line 235, in load
    self.load_weights(fname, strict=True)
  File "/Users/menzHSE/Development/miniforge3/envs/mlx-pytorch-2024-01/lib/python3.11/site-packages/mlx/nn/layers/base.py", line 158, in load_weights
    raise ValueError(
ValueError: Expected shape [16] but received  shape [1, 1, 1, 16] for parameter encoder.bn1.running_mean
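A strict load rejects any parameter whose saved shape differs from the model's, while a load without that validation can accept the file, which is consistent with strict=False appearing to work. The following is a minimal NumPy sketch of a strict check of this kind; `load_weights_strict` is hypothetical and is not MLX's `load_weights` implementation.

```python
import numpy as np


def load_weights_strict(expected, loaded):
    """Hypothetical strict check: names and shapes in `loaded` must match `expected`.

    `expected` maps parameter names to shapes; `loaded` maps names to arrays.
    """
    missing = sorted(set(expected) - set(loaded))
    unexpected = sorted(set(loaded) - set(expected))
    if missing or unexpected:
        raise ValueError(f"Missing: {missing}, unexpected: {unexpected}")
    for name, shape in expected.items():
        got = tuple(loaded[name].shape)
        if got != tuple(shape):
            raise ValueError(
                f"Expected shape {list(shape)} but received shape {list(got)} "
                f"for parameter {name}"
            )
    return loaded


# Recreates the kind of mismatch reported above for a 16-channel BatchNorm layer
expected = {"encoder.bn1.running_mean": (16,)}
saved = {"encoder.bn1.running_mean": np.zeros((1, 1, 1, 16))}
try:
    load_weights_strict(expected, saved)
except ValueError as err:
    print(err)
```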

menzHSE • Jan 09 '24 10:01

@menzHSE this is awesome! I think a VAE example is definitely in scope for this repo. Just as a warning, we have a bit of a review backlog, so it might be a little while before we can get to reviewing and merging it. But I would love to include this!

Also, the BatchNorm issue looks like a bug in saving and loading the BN stats. I will take a look at that.
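The received shape [1, 1, 1, 16] is what per-channel statistics look like when an NHWC feature map is reduced over the batch and spatial axes with keepdims=True, whereas the layer expects a flat [16] vector. The NumPy snippet below only illustrates that shape difference; where the extra axes came from is an assumption here, not a description of the fix in https://github.com/ml-explore/mlx/pull/409.

```python
import numpy as np

# A hypothetical NHWC activation: batch of 8, 14x14 spatial, 16 channels
x = np.random.default_rng(0).standard_normal((8, 14, 14, 16))

# Per-channel mean reduced over batch and spatial axes, with and without keepdims
mean_keepdims = x.mean(axis=(0, 1, 2), keepdims=True)
mean_flat = x.mean(axis=(0, 1, 2))

print(mean_keepdims.shape)  # (1, 1, 1, 16)
print(mean_flat.shape)      # (16,)

# Same values, different shapes: a reshape converts one form into the other
assert np.allclose(mean_keepdims.reshape(-1), mean_flat)
```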

awni • Jan 09 '24 14:01

Awesome! I'll keep working on it and will wait for the next mlx release anyway.

Regarding the BatchNorm issue: I tried but could not reproduce it with a minimal example. Maybe I am missing something in my model definition.

menzHSE • Jan 09 '24 15:01

The BatchNorm issue should be fixed in https://github.com/ml-explore/mlx/pull/409

awni • Jan 09 '24 15:01

Verified that https://github.com/ml-explore/mlx/pull/409 fixed the loading in strict mode with mlx@1d90a76

menzHSE • Jan 10 '24 09:01

  • Tested with mlx 0.0.9
  • Requirements and documentation updated
  • Style checked
  • Rebased and force-pushed

I think it is ready for review.

menzHSE • Jan 12 '24 12:01

Nice work, @menzHSE. I pulled your CVAE branch for the PR and tested it with the latest mlx-core and mlx-datasets on an MBP M2 Pro with 32 GB.

...
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 32.0KiB (53.1MiB/s)
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz 8.0KiB (13.2MiB/s)
Number of trainable params: 0.1493 M
Epoch    0 | Loss   19312.91 | Throughput   947.85 im/s | Time     64.4 (s)
Epoch    1 | Loss   11821.64 | Throughput   942.04 im/s | Time     64.7 (s)
Epoch    2 | Loss   10845.00 | Throughput   941.41 im/s | Time     64.8 (s)
Epoch    3 | Loss   10316.70 | Throughput   941.30 im/s | Time     64.8 (s)
Epoch    4 | Loss    9976.91 | Throughput   941.35 im/s | Time     64.8 (s)
Epoch    5 | Loss    9739.50 | Throughput   941.35 im/s | Time     64.8 (s)
Epoch    6 | Loss    9564.17 | Throughput   939.39 im/s | Time     64.9 (s)
Epoch    7 | Loss    9425.80 | Throughput   935.16 im/s | Time     65.2 (s)
Epoch    8 | Loss    9312.95 | Throughput   939.39 im/s | Time     64.9 (s)
python3 generate.py --model=pretrained/vae_mnist_filters_0064_dims_0008.npz  --latent_dims=8 --outfile=samples.png --seed=0
Loaded model with 8 latent dims from pretrained/vae_mnist_filters_0064_dims_0008.npz
Saved 128 generated images to samples.png
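Generating samples amounts to drawing latent codes from the prior and decoding them. The sketch below assumes a standard normal prior and lays the 128 decoded images out on a grid; `sample_latents`, `tile_images`, `fake_decoder`, and the 16-column layout are hypothetical, and NumPy's generator will not reproduce the samples that generate.py produces with --seed=0.

```python
import numpy as np


def sample_latents(num_samples, latent_dims, seed):
    """Draw latent codes from the standard normal prior N(0, I)."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((num_samples, latent_dims))


def tile_images(images, cols):
    """Arrange (N, H, W) images into one (rows*H, cols*W) grid, zero-padding the last row."""
    n, h, w = images.shape
    rows = -(-n // cols)  # ceiling division
    padded = np.zeros((rows * cols, h, w), dtype=images.dtype)
    padded[:n] = images
    grid = padded.reshape(rows, cols, h, w).transpose(0, 2, 1, 3)
    return grid.reshape(rows * h, cols * w)


def fake_decoder(z):
    """Hypothetical stand-in for the trained decoder: maps latents to 28x28 images."""
    return np.tanh(z.sum(axis=1))[:, None, None] * np.ones((1, 28, 28))


z = sample_latents(num_samples=128, latent_dims=8, seed=0)
grid = tile_images(fake_decoder(z), cols=16)
print(z.shape, grid.shape)  # (128, 8) (224, 448)
```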

bigsnarfdude • Jan 16 '24 19:01

Tested with mlx 0.0.11, rebased, and force-pushed.

menzHSE • Jan 26 '24 11:01

I left a couple of high-level questions; could you check?

I will do a more detailed review after hearing from you. FYI: when reviewing examples, I usually push nit fixes directly to the branch, and if bigger changes are needed, I leave comments.

awni • Jan 26 '24 17:01

@menzHSE is this ready for review?

awni • Feb 02 '24 03:02

@awni Yes. I have made the minor changes requested above and it's ready for review.

menzHSE • Feb 02 '24 06:02

@menzHSE this is really nice!

I took a deeper look today and simplified / reorganized the example to be more self-contained and in the style of our other small examples.

I want to update the README after finishing a full training run, but other than that I think we can merge this shortly. Let me know if you have any comments!

awni • Feb 06 '24 17:02

Thanks @awni. I have also started a training run.

menzHSE • Feb 06 '24 17:02

@menzHSE you might be interested in the updated README log. I notice the throughput I get on my M1 Max is much better than the numbers you had before. I'm wondering whether your Pro is that much slower or whether MLX has also gotten faster for you.
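Comparing throughput across machines and MLX versions is easier with the per-epoch numbers in machine-readable form. This is a small hypothetical helper; its regular expression assumes exactly the epoch-line format of the training log posted earlier in this thread.

```python
import re

# Matches epoch lines such as:
# "Epoch    0 | Loss   19312.91 | Throughput   947.85 im/s | Time     64.4 (s)"
EPOCH_RE = re.compile(
    r"Epoch\s+(\d+)\s+\|\s+Loss\s+([\d.]+)\s+\|\s+Throughput\s+([\d.]+)\s+im/s"
    r"\s+\|\s+Time\s+([\d.]+)\s+\(s\)"
)


def parse_epochs(log_text):
    """Return (epoch, loss, throughput_im_s, time_s) tuples for each epoch line."""
    return [
        (int(e), float(loss), float(tp), float(t))
        for e, loss, tp, t in EPOCH_RE.findall(log_text)
    ]


log = """\
Epoch    0 | Loss   19312.91 | Throughput   947.85 im/s | Time     64.4 (s)
Epoch    1 | Loss   11821.64 | Throughput   942.04 im/s | Time     64.7 (s)
Epoch    2 | Loss   10845.00 | Throughput   941.41 im/s | Time     64.8 (s)
"""
rows = parse_epochs(log)
mean_tp = sum(r[2] for r in rows) / len(rows)
print(f"{len(rows)} epochs, mean throughput {mean_tp:.2f} im/s")  # 3 epochs, mean throughput 943.77 im/s
```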

At any rate, I think this is good to merge. Thanks so much for the contribution, it's really nicely done!!

awni • Feb 07 '24 03:02

@awni Great! Thanks for merging. I am getting a throughput of approx. 485 im/s on my M1, which is quite a bit faster than the 410 im/s I got with previous mlx versions. Still a huge gap to your M1 Max ... time for an upgrade :)
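For scale, the two M1 numbers translate into the following speedup and per-epoch time. This assumes an epoch covers the 60,000 MNIST training images and ignores any non-training overhead, so real epoch times will be somewhat longer.

```python
# Throughput figures quoted above (images per second on an M1)
new_tp, old_tp = 485.0, 410.0

# Relative speedup of the newer MLX version
speedup = new_tp / old_tp
print(f"speedup: {speedup:.2f}x ({(speedup - 1) * 100:.0f}% faster)")  # speedup: 1.18x (18% faster)

# Assumed epoch size: the 60,000-image MNIST training set
n_train = 60_000
print(f"epoch time: {n_train / old_tp:.0f} s -> {n_train / new_tp:.0f} s")  # epoch time: 146 s -> 124 s
```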

menzHSE • Feb 07 '24 06:02