
How difficult would adapting FlashFFTConv for Conv2d with larger filter sizes be?

Open · mgiammar opened this issue 8 months ago · 1 comment

First off, really cool work, especially in terms of maximizing model performance on GPUs. I wanted to ask (as an extension of #22) how difficult it would be to implement/adapt FlashFFTConv for 2D convolution on large images (4096x4096) with similarly large filters (between 256x256 and 512x512).

My specific application is comparing large cryo-EM images against projections of a reference template using FFT-based cross-correlation. A basic package is located here; the main function, which gets called over and over, is at src/torch_2dtm/cross_correlate.py. Essentially, the method boils down to:

  1. Iterate over a large number of different projections (N ~1.6 million)
  2. Perform cross-correlation of each projection (with zero-padding) against the image
  3. Update and track some statistics

The FFT, conjugate multiply, and IFFT (step 2) account for ~80% of the computation time, so any improvement in the core cross-correlation step would lead to major speed gains.
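For concreteness, here is a minimal sketch of that hot path in PyTorch. The function name `fft_cross_correlate` and the exact shapes are my own for illustration and are not the actual torch_2dtm implementation; the real code in src/torch_2dtm/cross_correlate.py may batch projections and handle statistics differently.

```python
import torch

def fft_cross_correlate(image: torch.Tensor, template: torch.Tensor) -> torch.Tensor:
    """Cross-correlate a zero-padded template against a large image via FFT.

    image:    (H, W) real tensor, e.g. 4096 x 4096
    template: (h, w) real tensor, e.g. 512 x 512, zero-padded up to (H, W)
    """
    H, W = image.shape

    # Zero-pad the template to the image size before transforming.
    padded = torch.zeros(H, W, dtype=template.dtype, device=template.device)
    padded[: template.shape[0], : template.shape[1]] = template

    # FFT -> conjugate multiply -> inverse FFT (the ~80% hot path, step 2).
    image_fft = torch.fft.rfft2(image)
    template_fft = torch.fft.rfft2(padded)
    return torch.fft.irfft2(image_fft * template_fft.conj(), s=(H, W))

# Hypothetical outer loop (step 1); the statistics tracking of step 3 is elided.
# for projection in projections:   # N ~ 1.6 million
#     cc = fft_cross_correlate(image, projection)
#     ...update running statistics from cc...
```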

My reasons for opening a new GitHub issue break down into three questions:

  1. Would you expect to see similar speedups (~8x) for larger filters in 2D for this use case?
  2. Do you think the tensor cores could still effectively be utilized?
  3. What challenges/barriers would you envision for implementing a Conv2d operation in FlashFFTConv?

Any insight and/or feedback on this use case would be greatly appreciated!

mgiammar · Mar 28 '25 18:03

I would be interested in this too. I have a 4096x4096 image and a kernel of size 4096. I would love to code and learn this from scratch in CUDA!

ayushsvas · Apr 01 '25 09:04