Support for pooling layers for CNNs, e.g. MaxPool2d
Thanks to the mlx team for creating and sharing mlx.
I got a small CIFAR-10 image classification CNN up and running rather quickly in mlx (inspired by the PyTorch CIFAR-10 tutorial). However, I found that pooling layers (e.g. MaxPool2d) are not available yet. I hope they will be added in one of the next releases.
The code is here: https://github.com/menzHSE/mlx-cifar-10-cnn. It borrows heavily from the MNIST example in mlx-examples (https://github.com/ml-explore/mlx-examples/tree/main/mnist).
I don't know whether it will be in the next release, but we will probably implement fast pooling and/or upsampling operations.
In the meantime, you can implement some of these with array operations, like the 2D upsampling in the stable diffusion example:
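    import mlx.core as mx

    def upsample_nearest(x, scale: int = 2):
        # x is NHWC: (batch, height, width, channels)
        B, H, W, C = x.shape
        # Add singleton axes after H and W, then broadcast them to `scale`
        x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
        # Merge each (H, scale) and (W, scale) pair into one upsampled axis
        x = x.reshape(B, H * scale, W * scale, C)
        return x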
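To make the broadcast-and-reshape trick concrete, here is a small sketch of the same steps on a 2x2 input. It is written against NumPy rather than mlx purely so it runs anywhere; `upsample_nearest_np` is a hypothetical name, and the assumption is that `mx.broadcast_to` and `reshape` behave the same way on mlx arrays.

```python
import numpy as np

def upsample_nearest_np(x, scale=2):
    # NumPy stand-in for the mlx snippet above (assumption: same semantics)
    B, H, W, C = x.shape
    # y[b, h, i, w, j, c] = x[b, h, w, c] for every i, j in range(scale)
    y = np.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
    # Merging (h, i) and (w, j) repeats every pixel scale x scale times
    return y.reshape(B, H * scale, W * scale, C)

x = np.arange(4).reshape(1, 2, 2, 1)   # one 2x2 single-channel image
up = upsample_nearest_np(x)
print(up[0, :, :, 0])
# [[0 0 1 1]
#  [0 0 1 1]
#  [2 2 3 3]
#  [2 2 3 3]]
```

Each input pixel ends up as a `scale` x `scale` block, which is the same result as repeating the array along the height and width axes.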
Or use a strided convolution or reshapes to do pooling (I haven't tested the code below, but it should work):
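    def avg_pool_2d(x, stride: int = 2):
        # x is NHWC; H and W must be divisible by stride
        B, H, W, C = x.shape
        # Split H and W into (blocks, stride) and average within each block
        x = x.reshape(B, H // stride, stride, W // stride, stride, C).mean((2, 4))
        return x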
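Since the original request was for MaxPool2d, the same reshape trick could be sketched for max pooling by swapping the mean for a max. This is only a sketch, written against NumPy so it can be run without mlx; `max_pool_2d` and its `window` parameter are hypothetical names, the window is assumed to equal the stride (non-overlapping pooling), and cropping trailing rows and columns that do not fill a whole window is my own choice. The expectation, untested here, is that the same `reshape` and `max` calls carry over to mlx arrays.

```python
import numpy as np

def max_pool_2d(x, window=2):
    # x is NHWC; pooling windows are non-overlapping (stride == window)
    B, H, W, C = x.shape
    Ho, Wo = H // window, W // window
    # Assumption: drop trailing rows/columns that do not fill a full window
    x = x[:, :Ho * window, :Wo * window, :]
    # Split H and W into (blocks, window) and take the max within each block
    x = x.reshape(B, Ho, window, Wo, window, C)
    return x.max(axis=(2, 4))

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
y = max_pool_2d(x)
print(y[0, :, :, 0])
# [[ 5.  7.]
#  [13. 15.]]
```

Overlapping windows (stride smaller than the window) cannot be expressed with a plain reshape and would need a different approach.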
I have implemented some CNN layers here: https://github.com/robertmccraith/mimm. More models will be implemented over time; feel free to contribute.