mlx
Add groups to 2-D convolutions
Proposed changes
Added `groups` support to 2-D convolutions for some kernel specializations.
Also fixed 1-D grouped convolutions with different kernel strides and added more tests.
Should close #100.
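For readers unfamiliar with grouped convolutions, here is a minimal pure-Python reference sketch of what `groups` computes. It assumes a channels-last layout with input (N, H, W, C) and weight (O, kH, kW, C // groups), following the axis order in the benchmark table. The helper `conv2d_grouped_ref` is hypothetical: it is not the MLX API and not the kernel added in this PR. The input and output channels are split into `groups` equal slices, and each output channel reads only the input slice of its own group.

```python
def conv2d_grouped_ref(x, w, stride=(1, 1), padding=(0, 0), groups=1):
    """Naive grouped 2-D convolution (cross-correlation) on nested lists.

    x: input of shape (N, H, W, C), channels-last.
    w: weight of shape (O, kH, kW, C // groups).
    Returns the output as nested lists of shape (N, H_out, W_out, O).
    """
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    O, kH, kW, Cg = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    assert C % groups == 0 and O % groups == 0, "channels must be divisible by groups"
    assert Cg == C // groups, "weight must have C // groups input channels"
    sH, sW = stride
    pH, pW = padding
    H_out = (H + 2 * pH - kH) // sH + 1
    W_out = (W + 2 * pW - kW) // sW + 1
    O_per_group = O // groups

    out = [[[[0.0] * O for _ in range(W_out)] for _ in range(H_out)] for _ in range(N)]
    for n in range(N):
        for i in range(H_out):
            for j in range(W_out):
                for o in range(O):
                    g = o // O_per_group  # group this output channel belongs to
                    acc = 0.0
                    for di in range(kH):
                        for dj in range(kW):
                            h = i * sH + di - pH
                            col = j * sW + dj - pW
                            if not (0 <= h < H and 0 <= col < W):
                                continue  # implicit zero padding
                            # Only the Cg input channels of group g contribute.
                            for c in range(Cg):
                                acc += x[n][h][col][g * Cg + c] * w[o][di][dj][c]
                    out[n][i][j][o] = acc
    return out


# Two input channels, two output channels, groups=2, 1x1 kernel:
# output channel 0 only sees input channel 0, channel 1 only sees channel 1.
print(conv2d_grouped_ref([[[[1.0, 2.0]]]], [[[[3.0]]], [[[5.0]]]], groups=2))
```

With `groups=1` every output channel sees all C input channels; with `groups=C` and O = C the operation becomes a depthwise convolution.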
Performance looks pretty good:
| Input (N, H, W, C) | Weight (O, kH, kW, C) | dtype | stride | pads | groups | diff% |
|---|---|---|---|---|---|---|
| (4, 64, 64, 256) | (256, 5, 5, 256) | float32 | (1, 1) | (2, 2) | 1 | +25.78% |
| (4, 64, 64, 256) | (256, 5, 5, 256) | float32 | (1, 1) | (2, 2) | 2 | +36.72% |
| (4, 64, 64, 256) | (256, 5, 5, 256) | float32 | (1, 1) | (2, 2) | 16 | -23.80% |
| (4, 64, 64, 256) | (256, 5, 5, 256) | float32 | (1, 1) | (2, 2) | 64 | +92.72% |
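As a rough point of reference for the table, the sketch below counts the nominal multiply-accumulates (MACs) per configuration. It assumes no dilation and that each output element reduces over kH * kW * (C / groups) inputs, reading C in the table as the input channel count. `conv2d_macs` is a hypothetical helper; it says nothing about how the kernels in this PR are scheduled or why they perform as measured.

```python
def conv2d_macs(input_shape, weight_shape, stride, pads, groups):
    """Nominal multiply-accumulates for a grouped 2-D conv (no dilation)."""
    N, H, W, C = input_shape
    O, kH, kW, _ = weight_shape  # last weight dim ignored; C // groups used instead
    H_out = (H + 2 * pads[0] - kH) // stride[0] + 1
    W_out = (W + 2 * pads[1] - kW) // stride[1] + 1
    # Each output element reduces over kH * kW * (C // groups) input values.
    return N * H_out * W_out * O * kH * kW * (C // groups)


for g in (1, 2, 16, 64):
    macs = conv2d_macs((4, 64, 64, 256), (256, 5, 5, 256), (1, 1), (2, 2), g)
    print(f"groups={g:>2}: {macs / 1e9:.2f} GMACs")
```

With pads (2, 2) and a 5x5 kernel the spatial size stays 64x64, so increasing `groups` divides the nominal arithmetic by the same factor and the rows differ in total work as well as in kernel path.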
Checklist
Put an `x` in the boxes that apply.
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the necessary documentation (if needed)