
Support for pooling layers for CNNs, e.g. MaxPool2d

Open menzHSE opened this issue 1 year ago • 2 comments

Thanks to the mlx team for creating and sharing mlx.

I managed to get a small CIFAR-10 image classification CNN up and running rather quickly in mlx (inspired by the PyTorch CIFAR-10 tutorial). However, I found that pooling layers (e.g. MaxPool2d) are not available yet. I hope they will be available in the next release(s).

Code is here: https://github.com/menzHSE/mlx-cifar-10-cnn. It borrows heavily from the mnist example in mlx (https://github.com/ml-explore/mlx-examples/tree/main/mnist).

menzHSE avatar Dec 06 '23 15:12 menzHSE

Don't know if it is going to be in the next release, but we will probably implement fast pooling and/or upsampling operations.

In the meantime you can do some of those using array operations like the upsampling 2d in the stable diffusion example:

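import mlx.core as mx

def upsample_nearest(x, scale: int = 2):
    # x is NHWC: (batch, height, width, channels)
    B, H, W, C = x.shape
    # Insert singleton axes after H and W, then broadcast them to `scale`
    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
    # Merge each repeated axis back into H and W
    x = x.reshape(B, H * scale, W * scale, C)
    return x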

or use strided conv or reshapes to do pooling operations (haven't tested the code below but it should work):

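def avg_pool_2d(x, stride: int = 2):
    # x is NHWC; H and W must be divisible by stride
    B, H, W, C = x.shape
    # Split H and W into (blocks, stride) and average over each stride axis
    x = x.reshape(B, H // stride, stride, W // stride, stride, C).mean((2, 4))
    return x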
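
Since the issue asks about MaxPool2d specifically: the same reshape trick should carry over to max pooling by swapping the mean for a max. Below is a minimal sketch, written in NumPy so it can be checked on any machine. The name max_pool_2d is made up for illustration, and the sketch assumes NHWC input, non-overlapping windows (kernel size equal to stride) and H and W divisible by stride. Presumably the mlx version looks the same with mx arrays, but that is untested here.

```python
import numpy as np

def max_pool_2d(x, stride: int = 2):
    # Hypothetical helper, not an mlx API.
    # x is NHWC; H and W are assumed to be divisible by stride.
    B, H, W, C = x.shape
    # Split H and W into (blocks, stride) pairs ...
    x = x.reshape(B, H // stride, stride, W // stride, stride, C)
    # ... and take the max over the two within-window axes.
    return x.max(axis=(2, 4))

# 4x4 single-channel image holding 0..15 in row-major order
x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
y = max_pool_2d(x)
print(y[0, :, :, 0])  # 2x2 map of window maxima: [[5, 7], [13, 15]]
```

Overlapping windows (kernel size different from stride) or padding cannot be expressed with a plain reshape and would need a different approach.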

angeloskath avatar Dec 06 '23 22:12 angeloskath

I have implemented some CNN layers here: https://github.com/robertmccraith/mimm. More models will be implemented over time; feel free to contribute.

robertmccraith avatar Dec 12 '23 21:12 robertmccraith