mlx-examples
Updated CIFAR-10 ResNet example to use BatchNorm instead of LayerNorm
This PR updates the CIFAR-10 ResNet example to use BatchNorm instead of LayerNorm. It is tested with mlx@026ef9a, but it should probably wait until https://github.com/ml-explore/mlx/pull/385 (remove the retain_graph flag) appears in an official release. I am assuming this will be in mlx 0.0.8.
Changes
- Replaced `nn.LayerNorm` with `nn.BatchNorm` (test accuracy 0.807 after 100 epochs -> 0.833 after 30 epochs)
- Decreased default number of epochs from 100 to 30 (no improvement observed after approx. 25 epochs)
- Updated code comments and `README.md`
- Updated `requirements.txt` to require `mlx>=0.0.8`
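For readers unfamiliar with the distinction, the swap boils down to which axes the statistics are computed over. The following is a plain NumPy sketch (an illustration only, not the example's actual MLX code): BatchNorm standardizes each channel using statistics from the whole batch, while LayerNorm standardizes each sample using only its own features.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    # Normalize per channel over the batch and spatial axes:
    # every channel c shares statistics computed across all samples.
    mean = x.mean(axis=(0, 1, 2), keepdims=True)  # shape (1, 1, 1, C)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def layer_norm(x, eps=1e-5):
    # Normalize per sample over the feature axes:
    # every sample n uses only its own statistics.
    mean = x.mean(axis=(1, 2, 3), keepdims=True)  # shape (N, 1, 1, 1)
    var = x.var(axis=(1, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(8, 32, 32, 3)  # NHWC layout, as used by MLX convolutions
bn, ln = batch_norm(x), layer_norm(x)
```

(The learnable scale/shift parameters and BatchNorm's running statistics are omitted here for brevity; the real `nn.BatchNorm` layer carries both.)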
TODO
- [ ] Test with the next official mlx release and verify that `mlx>=0.0.8` in `requirements.txt` is still valid
I have noticed significantly decreased throughput vs. the info in the original README, with both measurements taken on a 16 GB M1 MacBook Pro. I am also seeing this without this PR applied (on current `main`).
Previous
Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
Epoch: 99 | Test acc 0.807
Current
Epoch: 29 | avg. Train loss 0.290 | avg. Train acc 0.899 | Throughput: 279.91 images/sec
Epoch: 29 | Test acc 0.833
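For context on how an images/sec figure like the ones above is typically derived (a sketch under assumptions; I have not checked the example's exact timing code), throughput is just the number of images processed divided by wall-clock time:

```python
import time

def measure_throughput(step_fn, num_images):
    """Return images/sec for one call to step_fn covering num_images images.

    step_fn is a hypothetical stand-in for one training epoch (or batch).
    """
    start = time.perf_counter()
    step_fn()
    elapsed = time.perf_counter() - start
    return num_images / elapsed

# Example with a dummy workload standing in for a training step:
rate = measure_throughput(lambda: sum(range(1_000_000)), num_images=256)
print(f"Throughput: {rate:.2f} images/sec")
```

Note that such numbers fold in data loading and any lazy-evaluation sync points, so they are sensitive to the whole machine, not just the GPU.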
@menzHSE did you record the "Previous" results on the same machine? In general it's expected that there can be considerable variability between types of hardware.
@awni I did not record the "Previous" results; they are taken from the original `README.md`. I have tried to reproduce them with earlier mlx versions (down to 0.0.2) but always get similar throughput. So whoever reported the initial results must have a better M1 MBP than mine ;-) It is just coincidentally the same machine.
Well that is very strange... let me time it on my machine..
I am on `main` right now with a 32GB M1 Max. Clearly that seems to help a lot... I don't have access to a 16GB M1 Pro, but I also plan to try on an M1 Mini and/or M2 Mini:
Epoch 00 [020] | Train loss 2.215 | Train acc 0.223 | Throughput: 624.82 images/second
It might be expected that the Max is 2x the Pro, so your numbers may be pretty reasonable. But now I'd like to know where the original #s came from...
Updated requirements.txt to require mlx>=0.0.9. Rebased and force-pushed.
Tested with mlx 0.0.9. Ready to be merged.
Thanks @menzHSE !