Group support for `conv3d`

Open mlaves opened this issue 1 year ago • 3 comments

Proposed changes

Added group support to the `conv3d` forward pass (the backward pass is still missing, as it is for `conv2d`). I adapted `slow_conv_3D`, `explicit_gemm_conv_ND_cpu`, and `conv_3D_gpu` to support `groups > 1` and added tests.
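For readers unfamiliar with grouped convolution, here is a minimal NumPy sketch of the semantics the forward pass is meant to implement. It is not the code from this PR: the function name `conv3d_ref` is hypothetical, and it assumes a channels-last input of shape `(N, D, H, W, C_in)`, a weight of shape `(C_out, KD, KH, KW, C_in // groups)`, stride 1, and no padding.

```python
import numpy as np


def conv3d_ref(x, w, groups=1):
    """Naive grouped 3D convolution (hypothetical reference, not the PR code).

    Assumed layout: x is (N, D, H, W, C_in), w is (C_out, KD, KH, KW, C_in // groups).
    Stride 1, no padding, no dilation.
    """
    n, d, h, wd, c_in = x.shape
    c_out, kd, kh, kw, c_in_g = w.shape
    if c_in % groups or c_out % groups:
        raise ValueError("channel counts must be divisible by groups")
    if c_in_g != c_in // groups:
        raise ValueError("weight's last axis must equal C_in // groups")

    c_out_g = c_out // groups
    od, oh, ow = d - kd + 1, h - kh + 1, wd - kw + 1
    out = np.zeros((n, od, oh, ow, c_out), dtype=x.dtype)

    for g in range(groups):
        # Each group sees only its own slice of input and output channels.
        xg = x[..., g * c_in_g:(g + 1) * c_in_g]
        wg = w[g * c_out_g:(g + 1) * c_out_g]
        for i in range(kd):
            for j in range(kh):
                for k in range(kw):
                    # Shifted input window: (N, OD, OH, OW, C_in // groups)
                    patch = xg[:, i:i + od, j:j + oh, k:k + ow, :]
                    # Contract over the group's input channels.
                    out[..., g * c_out_g:(g + 1) * c_out_g] += patch @ wg[:, i, j, k, :].T
    return out
```

A grouped convolution should match a dense one whose weight is block-diagonal across channels, which gives a simple way to check an implementation against a `groups=1` baseline.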

Checklist

  • [x] I have read the CONTRIBUTING document
  • [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

mlaves avatar Jul 11 '24 20:07 mlaves

Looks good to me. Do you mind sharing the results of a benchmark to see how well the groups parameter works for 3D convs? Ideally we should get a nice speed up when using a large number of groups.
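As a rough guide to what a benchmark might show, the multiply-accumulate count of a grouped convolution shrinks in proportion to the number of groups, since each output channel only reads `C_in / groups` input channels. The sketch below is just that arithmetic; the helper name `conv3d_macs` is hypothetical, and measured wall-clock speed-ups depend on the kernel implementation and may be smaller than this ratio.

```python
def conv3d_macs(out_shape, kernel, c_in, c_out, groups=1):
    """Multiply-accumulate count for a 3D convolution (hypothetical helper).

    out_shape: (N, OD, OH, OW) of the output, excluding channels.
    kernel:    (KD, KH, KW).
    """
    if c_in % groups or c_out % groups:
        raise ValueError("channel counts must be divisible by groups")
    n, od, oh, ow = out_shape
    kd, kh, kw = kernel
    # Every output element sums over the kernel volume and C_in // groups channels.
    return n * od * oh * ow * c_out * kd * kh * kw * (c_in // groups)


dense = conv3d_macs((1, 8, 8, 8), (3, 3, 3), c_in=32, c_out=32, groups=1)
depthwise = conv3d_macs((1, 8, 8, 8), (3, 3, 3), c_in=32, c_out=32, groups=32)
print(dense // depthwise)  # theoretical reduction factor equals groups
```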

awni avatar Jul 15 '24 22:07 awni

> Looks good to me. Do you mind sharing the results of a benchmark to see how well the groups parameter works for 3D convs? Ideally we should get a nice speed up when using a large number of groups.

Thanks, I'll add some benchmarks!

mlaves avatar Jul 16 '24 18:07 mlaves

@mlaves what's the status of this PR? Should we try and land it soon?

awni avatar Oct 14 '24 16:10 awni