How difficult would adapting FlashFFTConv for Conv2d with larger filter sizes be?
First off, really cool work, especially in terms of maximizing model performance on GPUs. I wanted to ask (as an extension of #22) how difficult it would be to implement/adapt FlashFFTConv for doing 2D convolution on large images (4096x4096) with similarly large filters (between 256x256 and 512x512).
My specific application here is comparing large cryo-EM images with projections of a reference template using FFT-based cross-correlation. A basic package is located here; the main function, which gets called over and over, is at src/torch_2dtm/cross_correlate.py. Essentially, the method boils down to:
- Iterating over a large number of different projections (N ~1.6 million)
- Performing cross-correlation of each projection (with zero-padding) against an image
- Updating and tracking some statistics
The FFT, conjugate multiply, and IFFT (the second step above) account for ~80% of the computation time, so any improvement in the core cross-correlate step would lead to major speed gains.
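For concreteness, here is a minimal sketch of that core step, using NumPy's FFT to keep it self-contained (the actual package uses torch, and `fft_cross_correlate` is a hypothetical name, not the function in src/torch_2dtm/cross_correlate.py). It zero-pads the template to the image size and applies the identity corr = IFFT(FFT(image) * conj(FFT(template))):

```python
import numpy as np

def fft_cross_correlate(image, template):
    """Cross-correlate a small template against a large image via FFT.

    Zero-pads the template to the image shape, then uses
    corr = IFFT(FFT(image) * conj(FFT(template))),
    which gives the circular cross-correlation map.
    """
    H, W = image.shape
    th, tw = template.shape
    # Zero-pad the template up to the full image size
    padded = np.zeros((H, W), dtype=image.dtype)
    padded[:th, :tw] = template
    # Forward FFTs, conjugate multiply, inverse FFT -- the ~80% hot path
    img_f = np.fft.rfft2(image)
    tmpl_f = np.fft.rfft2(padded)
    return np.fft.irfft2(img_f * np.conj(tmpl_f), s=(H, W))

# Toy check: plant an 8x8 template inside a 64x64 image at offset (20, 30);
# the correlation peak should land at that offset.
rng = np.random.default_rng(0)
template = rng.random((8, 8))
image = np.zeros((64, 64))
image[20:28, 30:38] = template
corr = fft_cross_correlate(image, template)
peak = np.unravel_index(np.argmax(corr), corr.shape)
```

In the real pipeline this inner product would typically be normalized (e.g. dividing by local variance for a matched filter), and the template FFT can be batched over all N projections; the sketch above only shows the unnormalized single-pair case.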
My reasoning behind opening a new GitHub issue breaks down into three parts:
- Would you expect to see similar speedups (~8x) for larger filters in 2D for this use case?
- Do you think the tensor cores could still be utilized effectively?
- What challenges/barriers would you envision for implementing a Conv2d operation in FlashFFTConv?
Any insight and/or feedback on this use case would be greatly appreciated!
I would be interested in this too. I have a 4096x4096 image and a kernel of size 4096. I would love to learn this by coding it from scratch in CUDA!