mlx

Conv3d

Open mlaves opened this issue 1 year ago • 3 comments

Proposed changes

  • Implementation of naive slow_conv_3D on the CPU.
  • Implementation of explicit_gemm_conv_ND_cpu.
    • However, using this seems to be considerably slower than the naive implementation. I guess that materializing the strided input view takes a very long time. The actual gemm is quite fast. Therefore, this is currently unused in conv3d routing.
  • Usage of explicit_gemm_conv_ND_gpu for GPU implementation of conv3d.
  • Added tests for conv3d.
  • Fixed handling of negative padding. The old behavior led to incorrect shapes of the conv gradient w.r.t. the input in some specific kernel size/padding/dilation combinations.
  • I also fixed two typos I spotted in error messages in the code.
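
For readers unfamiliar with the two CPU strategies named above, here is a minimal NumPy sketch of the idea, not the code from this PR. It assumes a channels-last layout (input `(N, D, H, W, C_in)`, weight `(C_out, kD, kH, kW, C_in)`), stride 1, no padding, and no dilation. The function names `naive_conv3d` and `explicit_gemm_conv3d` are hypothetical. The explicit-gemm variant builds a zero-copy strided view of all input patches, materializes it into a contiguous matrix (the step suspected of being slow), and then performs a single matrix multiply.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def naive_conv3d(x, w):
    """Direct conv3d: loop over output positions (stride 1, no padding)."""
    N, D, H, W, C = x.shape
    O, kD, kH, kW, _ = w.shape
    oD, oH, oW = D - kD + 1, H - kH + 1, W - kW + 1
    out = np.zeros((N, oD, oH, oW, O), dtype=x.dtype)
    for d in range(oD):
        for h in range(oH):
            for wi in range(oW):
                # Patch under the kernel: (N, kD, kH, kW, C)
                patch = x[:, d:d + kD, h:h + kH, wi:wi + kW, :]
                # Contract kernel and channel axes -> (N, O)
                out[:, d, h, wi, :] = np.tensordot(
                    patch, w, axes=([1, 2, 3, 4], [1, 2, 3, 4])
                )
    return out


def explicit_gemm_conv3d(x, w):
    """Explicit-gemm conv3d: strided view -> contiguous copy -> one matmul."""
    N = x.shape[0]
    O, kD, kH, kW, C = w.shape
    # Zero-copy view of every patch: (N, oD, oH, oW, C, kD, kH, kW)
    view = sliding_window_view(x, (kD, kH, kW), axis=(1, 2, 3))
    # Reorder to (N, oD, oH, oW, kD, kH, kW, C) to match the weight layout
    view = np.moveaxis(view, 4, -1)
    oD, oH, oW = view.shape[1:4]
    # Materialize the view; this copy is the analogue of the slow step
    cols = np.ascontiguousarray(view).reshape(N * oD * oH * oW, kD * kH * kW * C)
    # Single gemm: (positions, patch) @ (patch, C_out)
    out = cols @ w.reshape(O, -1).T
    return out.reshape(N, oD, oH, oW, O)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 6, 7, 8, 3))
w = rng.standard_normal((4, 3, 3, 3, 3))
same = np.allclose(naive_conv3d(x, w), explicit_gemm_conv3d(x, w))
```

Both functions should agree up to floating-point error. The materialized matrix is roughly `kD * kH * kW` times larger than the input, which is why the copy can dominate the runtime.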

Checklist

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

mlaves avatar Apr 13 '24 20:04 mlaves

> Implementation of explicit_gemm_conv_ND_cpu. However, using this seems to be considerably slower than the naive implementation. I guess that materializing the strided input view takes a very long time. The actual gemm is quite fast. Therefore, this is currently unused in conv3d routing.

Should we delete it? If it's not fast and has no hope of being fast, I don't see much point to leaving it in.

awni avatar Apr 25 '24 03:04 awni

@mlaves are you planning to update this?

awni avatar May 03 '24 20:05 awni

> @mlaves are you planning to update this?

Yes, I will update my PR soon.

mlaves avatar May 03 '24 21:05 mlaves

> > Implementation of explicit_gemm_conv_ND_cpu. However, using this seems to be considerably slower than the naive implementation. I guess that materializing the strided input view takes a very long time. The actual gemm is quite fast. Therefore, this is currently unused in conv3d routing.

> Should we delete it? If it's not fast and has no hope of being fast, I don't see much point to leaving it in.

The gemm operation itself is fast, but the reshaping copy, copy(in_strided_view, in_strided, CopyType::General), is slow. PyTorch's conv3d CPU implementation is 10x faster and also uses gemm-conv. Should I remove this now, or is there any chance the reshaping will be faster in the future?
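
As a rough illustration of how the two phases can be timed separately, here is a small NumPy sketch. It is not a benchmark of MLX: it only mirrors the structure (strided view, contiguous copy, gemm) under the same assumptions as before (channels-last layout, stride 1, no padding). The helper name `time_gemm_conv_phases` is hypothetical, and NumPy timings say nothing definitive about the cost of MLX's `CopyType::General` copy.

```python
import time

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def time_gemm_conv_phases(x, w):
    """Return (out, copy_seconds, gemm_seconds) for a stride-1, unpadded conv3d."""
    N = x.shape[0]
    O, kD, kH, kW, C = w.shape
    # Zero-copy patch view, reordered to (N, oD, oH, oW, kD, kH, kW, C)
    view = np.moveaxis(
        sliding_window_view(x, (kD, kH, kW), axis=(1, 2, 3)), 4, -1
    )
    oD, oH, oW = view.shape[1:4]

    t0 = time.perf_counter()
    # Phase 1: materialize the strided view into a contiguous matrix
    cols = np.ascontiguousarray(view).reshape(-1, kD * kH * kW * C)
    t1 = time.perf_counter()
    # Phase 2: the gemm itself
    out = cols @ w.reshape(O, -1).T
    t2 = time.perf_counter()
    return out.reshape(N, oD, oH, oW, O), t1 - t0, t2 - t1


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 16, 16, 16, 8)).astype(np.float32)
w = rng.standard_normal((16, 3, 3, 3, 8)).astype(np.float32)
out, copy_s, gemm_s = time_gemm_conv_phases(x, w)
print(f"copy: {copy_s:.6f}s  gemm: {gemm_s:.6f}s  out shape: {out.shape}")
```

Timings will vary by machine and input size; a single run like this is only a sanity check, and a real comparison would repeat each phase and take the minimum or median.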

mlaves avatar May 07 '24 20:05 mlaves

> Should I remove this now or is there any chance the reshaping will be faster in the future?

It's possible, we haven't looked at / optimized the CPU copies in a while.

awni avatar May 07 '24 20:05 awni

> > Should I remove this now or is there any chance the reshaping will be faster in the future?

> It's possible, we haven't looked at / optimized the CPU copies in a while.

In this case, I would keep explicit_gemm_conv_ND_cpu, even though it's unused now.

mlaves avatar May 07 '24 20:05 mlaves

Sounds good to me!

awni avatar May 07 '24 20:05 awni

I incorporated all reviewer suggestions.

mlaves avatar May 07 '24 21:05 mlaves

I rebased onto main; CI should pass now.

mlaves avatar May 09 '24 21:05 mlaves