pyvkfft

Convolution support

Open vincefn opened this issue 1 year ago • 8 comments

Hi @DTolm, now that the release is out, I ran some tests of on-the-fly convolution, following Osamu's email exchange, which piqued my curiosity.

There is now a branch with convolution support.

What I have seen so far (from a very limited number of tests):

  • 2D and batched 2D convolutions work for R2C and C2C (compared against a numpy convolution), for both out-of-place and in-place transforms
  • for batched 2D, I use config->coordinateFeatures = n_batch instead of numberBatches; I assume that is the proper way
  • 1D transforms fail at kernel compilation (see the notebook below)
  • 3D transforms give incorrect results
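The numpy comparison mentioned above can be sketched as a batched circular 2D convolution computed with FFTs, with one kernel per image (what `coordinateFeatures = n_batch` is assumed to give here), checked against a direct sum. The function names are illustrative only and are not part of pyvkfft or VkFFT.

```python
import numpy as np

def conv2_c2c(a, k):
    """Circular 2D convolution over the last two axes using a full complex FFT (C2C-like)."""
    return np.fft.ifft2(np.fft.fft2(a) * np.fft.fft2(k))

def conv2_r2c(a, k):
    """Same convolution for real inputs, using the half spectrum (R2C-like)."""
    shape = a.shape[-2:]
    return np.fft.irfft2(np.fft.rfft2(a) * np.fft.rfft2(k), s=shape)

def conv2_direct(a, k):
    """Direct circular convolution: out[y, x] = sum_{dy, dx} a[y - dy, x - dx] * k[dy, dx]."""
    ny, nx = a.shape[-2:]
    out = np.zeros(np.broadcast_shapes(a.shape, k.shape), dtype=np.result_type(a, k))
    for dy in range(ny):
        for dx in range(nx):
            # np.roll gives rolled[y, x] == a[y - dy, x - dx], with indices wrapping around
            out += np.roll(a, (dy, dx), axis=(-2, -1)) * k[..., dy:dy + 1, dx:dx + 1]
    return out

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 8, 7))  # n_batch = 3 images
k = rng.standard_normal((3, 8, 7))  # one kernel per image
ref = conv2_direct(a, k)
print(np.abs(conv2_c2c(a, k) - ref).max(), np.abs(conv2_r2c(a, k) - ref).max())
```

Using an odd nx (7) also exercises the `s=` argument of `irfft2`, which is needed to recover odd-length real outputs from a half spectrum.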

The tests are all in the pyvkfft-convolve notebook.

vincefn, Feb 07 '24

Hello

The convolution code was designed 3 years ago; back then I mostly implemented what I needed for my Master's thesis: matrix-vector convolutions for multidimensional systems. So not everything has been fully implemented. R2C convolutions breaking in 1D is expected, as at the time I didn't know how to combine R2C decomposition and convolution in a single kernel. This will be easier to implement now with the new modular structure.

The failing 3D convolutions seem to be caused by a bug that I think I have fixed on the dev branch (the modified test 51 now passes).

numberBatches should now also work the same way as coordinateFeatures (unless you use the matrix-vector functionality).

DTolm, Feb 08 '24

Thanks! I am almost finished with the release (I needed to update the conda packages), so I can look at this. The 3D convolutions work nicely, thanks!

I was wondering (since I do not yet fully understand how coordinateFeatures works) whether it is possible to do the following: use an array of shape n_batch*ny*nx and perform a 2D convolution with a single kernel array of size ny*nx. Can numberBatches and coordinateFeatures be configured to do that?

This can be very useful to compute the cross-correlation of N images against a single reference, or for near-field propagation (which is practically a 2D convolution of a stack of arrays with the same kernel).
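For the near-field propagation case, the single kernel is a transfer function applied in Fourier space to every frame of the stack. A minimal numpy sketch, assuming the paraxial (Fresnel) transfer function and parameter names chosen here for illustration; with on-the-fly convolution the multiplication would presumably happen inside the FFT kernel instead of as a separate step.

```python
import numpy as np

def fresnel_propagate(stack, wavelength, z, pixel_size):
    """Propagate every 2D frame of `stack` (shape (n_batch, ny, nx)) by a distance z.

    A single transfer function (the shared 'kernel', shape (ny, nx)) is computed once
    and broadcast over the batch axis. All lengths must use the same unit.
    """
    ny, nx = stack.shape[-2:]
    fy = np.fft.fftfreq(ny, d=pixel_size)[:, None]
    fx = np.fft.fftfreq(nx, d=pixel_size)[None, :]
    # Paraxial (Fresnel) transfer function; the constant phase exp(ikz) is dropped
    h = np.exp(-1j * np.pi * wavelength * z * (fx**2 + fy**2))
    return np.fft.ifft2(np.fft.fft2(stack) * h)  # h broadcasts over the batch axis

rng = np.random.default_rng(0)
stack = rng.standard_normal((4, 32, 32)) + 1j * rng.standard_normal((4, 32, 32))
out = fresnel_propagate(stack, wavelength=1e-10, z=0.5, pixel_size=1e-6)
```

Since |h| = 1 everywhere, propagating by z and then by -z should return the original stack, and the total intensity is preserved.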

vincefn, Feb 17 '24

coordinateFeatures should behave the same as numberBatches, unless you do matrix-vector multiplication convolutions (the kernel is a matrix, the system is a vector). It is a second form of batching (both can be used at the same time), since omitDimension does not currently work with convolutions.
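To make the matrix-vector case concrete: with a kernel that is an m×n matrix of 2D kernels and a system that is a vector of n 2D fields, each output component is a sum of convolutions, which in Fourier space becomes a small matrix-vector product at every frequency. This numpy sketch shows the mathematics only; the axis order is an assumption and says nothing about how VkFFT lays out the components in memory.

```python
import numpy as np

def matvec_convolve(kernel, system):
    """Matrix-vector circular convolution.

    kernel: (m, n, ny, nx), an m x n matrix of 2D kernels
    system: (n, ny, nx), a vector of n 2D fields
    returns (m, ny, nx) with out[i] = sum_j kernel[i, j] (*) system[j]
    """
    kf = np.fft.fft2(kernel)  # (m, n, ny, nx)
    sf = np.fft.fft2(system)  # (n, ny, nx)
    # At every frequency (y, x): an m x n matrix times an n-vector
    return np.fft.ifft2(np.einsum("ijyx,jyx->iyx", kf, sf))

rng = np.random.default_rng(0)
kernel = rng.standard_normal((2, 3, 4, 5))
system = rng.standard_normal((3, 4, 5))
out = matvec_convolve(kernel, system)

# Direct check: ref[i, y, x] = sum_j sum_{dy,dx} system[j, y-dy, x-dx] * kernel[i, j, dy, dx]
ref = np.zeros((2, 4, 5))
for i in range(2):
    for j in range(3):
        for dy in range(4):
            for dx in range(5):
                ref[i] += np.roll(system[j], (dy, dx), axis=(0, 1)) * kernel[i, j, dy, dx]
```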

As for multiple-input, single-kernel convolutions, so far I have only implemented the reverse (one system, multiple kernels, multiple outputs). I can make it work for this case as well, since using the reverse option should have worse performance.
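Since circular convolution is commutative, "many systems, one kernel" and "one system, many kernels" give the same numbers once the roles of the two operands are swapped; this is presumably why the reverse mode can serve as a workaround, even though a dedicated mode is expected to perform better. A numpy illustration, with hypothetical function names:

```python
import numpy as np

def one_system_many_kernels(system, kernels):
    """system: (ny, nx), kernels: (n_k, ny, nx) -> one output per kernel, (n_k, ny, nx)."""
    return np.fft.ifft2(np.fft.fft2(system) * np.fft.fft2(kernels))

def many_systems_one_kernel(systems, kernel):
    """systems: (n_batch, ny, nx), kernel: (ny, nx) -> one output per system, (n_batch, ny, nx)."""
    return np.fft.ifft2(np.fft.fft2(systems) * np.fft.fft2(kernel))

rng = np.random.default_rng(0)
images = rng.standard_normal((5, 16, 16))
kernel = rng.standard_normal((16, 16))
direct = many_systems_one_kernel(images, kernel)
# Reverse workaround: treat the single kernel as the 'system' and the images as 'kernels'
swapped = one_system_many_kernels(kernel, images)
```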

DTolm, Feb 21 '24

> As for multiple input - single kernel convolutions, I have only implemented reverse so far (one system, multiple kernels, multiple outputs). I can make it work for this case as well, as using the reverse option should have worse performance.

Yes, this would be very useful, e.g. for aligning multiple (batched) images against a single reference: one batch of 2D images and a single reference image.
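A numpy sketch of that alignment use case, assuming integer-pixel shifts and circular boundaries: the reference spectrum is computed once and shared by every image in the batch, which is the single-kernel, multiple-batch pattern discussed above. The function name is illustrative.

```python
import numpy as np

def align_to_reference(images, reference):
    """Integer-pixel (dy, dx) shift of each image in `images` (n, ny, nx) relative to `reference` (ny, nx)."""
    # Circular cross-correlation; the conjugated reference spectrum is shared by all images
    cc = np.fft.ifft2(np.fft.fft2(images) * np.conj(np.fft.fft2(reference)))
    ny, nx = reference.shape
    flat = np.abs(cc).reshape(len(images), -1).argmax(axis=1)
    dy, dx = np.unravel_index(flat, (ny, nx))
    # Map peak positions to signed shifts in [-n/2, n/2)
    dy = (dy + ny // 2) % ny - ny // 2
    dx = (dx + nx // 2) % nx - nx // 2
    return np.stack([dy, dx], axis=1)

rng = np.random.default_rng(0)
ref = rng.standard_normal((32, 32))
shifts = [(3, -5), (0, 0), (-7, 10)]
images = np.stack([np.roll(ref, s, axis=(0, 1)) for s in shifts])
print(align_to_reference(images, ref))
```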

vincefn, Feb 21 '24

@vincefn I have added an option for this functionality, and an example (53) showing how to set it up. It didn't require any big changes, so hopefully it will work straight away.

DTolm, Mar 02 '24

Hi @DTolm, I've now updated the code to support various types of batched transforms, e.g. an array of shape (nbatch, ny, nx) with a kernel of the same shape, a smaller kernel of shape (ny, nx) using singleKernelMultipleBatches, or even (nbatchk, ny, nx) using both singleKernelMultipleBatches and coordinateFeatures, as long as nbatch is a multiple of nbatchk. Nice!
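A numpy reference for the three kernel shapes listed above could look like the sketch below. The pairing rule in the grouped case is an assumption, not something stated in this thread: image i is taken to use kernel i % nbatchk, i.e. the data is viewed as (nbatch // nbatchk, nbatchk, ny, nx) with the kernel stack shared along the outer axis. The function name is hypothetical.

```python
import numpy as np

def batched_convolve(data, kernel):
    """Circular 2D convolution of data (nbatch, ny, nx) with a kernel of shape
    (ny, nx), (nbatch, ny, nx) or (nbatchk, ny, nx) where nbatch % nbatchk == 0.

    ASSUMPTION: in the grouped case, image i uses kernel i % nbatchk.
    """
    nbatch, ny, nx = data.shape
    if kernel.shape[-2:] != (ny, nx):
        raise ValueError("kernel and data must have the same 2D shape")
    if kernel.ndim == 2:
        kernel = kernel[None]  # a single kernel shared by all images
    nbatchk = kernel.shape[0]
    if nbatch % nbatchk:
        raise ValueError("nbatch must be a multiple of the number of kernels")
    d = data.reshape(nbatch // nbatchk, nbatchk, ny, nx)
    out = np.fft.ifft2(np.fft.fft2(d) * np.fft.fft2(kernel))  # kernel broadcasts over axis 0
    return out.reshape(nbatch, ny, nx)

rng = np.random.default_rng(0)
data = rng.standard_normal((6, 4, 5))
kernels = rng.standard_normal((3, 4, 5))  # nbatchk = 3, nbatch = 6
out = batched_convolve(data, kernels)
```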

I have now found issues with some odd-sized transforms: e.g. (on my Mac) a 1D+convolution transform of size 21 (3×7) fails with a compilation error, and the same happens for many odd sizes with small prime factors. This is true for C2C as well as R2C, regardless of the other parameters (in-place or out-of-place, batched or not, ...).

Also, in-place R2C transforms work for 2D and 3D, but out-of-place ones do not. I guess that's not possible, since the complex array needs 1 or 2 extra values (padding along the fastest axis), and I don't think there's an easy way around this. (Incidentally, this is related to the discussion in https://github.com/DTolm/VkFFT/issues/159.)
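For reference, the 1 or 2 extra values follow from the size of the half spectrum. A small numpy sketch (the helper name is hypothetical) of the padded buffer an in-place R2C transform needs, and of how that buffer aliases the complex output:

```python
import numpy as np

def r2c_inplace_shape(shape):
    """Shape of the real buffer needed for an in-place R2C transform of a real array of `shape`.

    The half spectrum holds nx // 2 + 1 complex values along the fastest axis, i.e.
    2 * (nx // 2 + 1) real values: 2 extra for even nx, 1 extra for odd nx.
    """
    *rest, nx = shape
    return (*rest, 2 * (nx // 2 + 1))

ny, nx = 4, 7
data = np.random.default_rng(0).standard_normal((ny, nx)).astype(np.float32)
buf = np.zeros(r2c_inplace_shape(data.shape), dtype=np.float32)  # (4, 8): one padding column
buf[:, :nx] = data
# The same memory reinterpreted as the complex half spectrum, shape (4, 4)
spectrum_view = buf.view(np.complex64)
```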

vincefn, Mar 10 '24

Hi @DTolm,

I've updated the code so that batched convolution can also be used with CUDA, and the systematic command-line test can now also be used for convolutions.

I've also clarified which systems work (C2C, in-place R2C with ndim>1, radix sizes, single upload only).

Here's an example of a test for sizes between 2 and 128, C2C out-of-place, where we can see which radix sizes fail (here using CUDA on an A4500). The failures are always compilation errors, some with a missing }. Contrary to my previous message, both odd and even sizes fail:

  pycuda C2C⨂           (2,2) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=8.6e-08 ninf=9.6e-08 < 2.3e-06 (0.042) 1 buf=    0   OK  
  pycuda C2C⨂           (3,3) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.7e-07 ninf=1.6e-07 < 2.5e-06 (0.066) 1 buf=    0   OK  
  pycuda C2C⨂           (4,4) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.6e-07 ninf=1.4e-07 < 2.6e-06 (0.055) 1 buf=    0   OK  
  pycuda C2C⨂           (5,5) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.1e-07 ninf=2.6e-07 < 2.7e-06 (0.095) 1 buf=    0   OK  
  pycuda C2C⨂           (6,6) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.7e-07 ninf=1.4e-07 < 2.8e-06 (0.049) 1 buf=    0   OK  
  pycuda C2C⨂           (7,7) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.9e-07 ninf=2.2e-07 < 2.8e-06 (0.078) 1 buf=    0   OK  
  pycuda C2C⨂           (8,8) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.6e-07 ninf=1.7e-07 < 2.9e-06 (0.060) 1 buf=    0   OK  
  pycuda C2C⨂           (9,9) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.9e-07 ninf=2.3e-07 < 3.0e-06 (0.079) 1 buf=    0   OK  
  pycuda C2C⨂         (10,10) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.3e-07 ninf=2.4e-07 < 3.0e-06 (0.080) 1 buf=    0   OK  
  pycuda C2C⨂         (11,11) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.2e-07 ninf=2.2e-07 < 3.0e-06 (0.071) 1 buf=    0   OK  
  pycuda C2C⨂         (12,12) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.1e-07 ninf=2.4e-07 < 3.1e-06 (0.079) 1 buf=    0   OK  
  pycuda C2C⨂         (13,13) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.0e-07 ninf=2.6e-07 < 3.1e-06 (0.084) 1 buf=    0   OK  
  pycuda C2C⨂         (14,14) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.8e-07 ninf=2.2e-07 < 3.1e-06 (0.071) 1 buf=    0   OK  
  pycuda C2C⨂         (15,15) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.3e-07 ninf=2.7e-07 < 3.2e-06 (0.085) 1 buf=    0   OK  
  pycuda C2C⨂         (16,16) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.6e-07 ninf=3.3e-07 < 3.2e-06 (0.102) 1 buf=    0   OK  
  pycuda C2C⨂         (18,18) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=4.9e-07 ninf=4.9e-07 < 3.3e-06 (0.150) 1 buf=    0   OK  
  pycuda C2C⨂         (20,20) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.9e-07 ninf=6.0e-07 < 3.3e-06 (0.182) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(21, 21), primes='3×7', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(22, 22), primes='2×11', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (24,24) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=3.4e-07 ninf=3.4e-07 < 3.4e-06 (0.102) 1 buf=    0   OK  
  pycuda C2C⨂         (25,25) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.3e-07 ninf=8.9e-07 < 3.4e-06 (0.263) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(26, 26), primes='2×13', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (27,27) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.4e-07 ninf=6.1e-07 < 3.4e-06 (0.178) 1 buf=    0   OK  
  pycuda C2C⨂         (28,28) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=3.0e-07 ninf=3.8e-07 < 3.4e-06 (0.111) 1 buf=    0   OK  
  pycuda C2C⨂         (30,30) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.5e-07 ninf=6.4e-07 < 3.5e-06 (0.185) 1 buf=    0   OK  
  pycuda C2C⨂         (32,32) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=3.4e-07 ninf=3.5e-07 < 3.5e-06 (0.100) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(33, 33), primes='3×11', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (35,35) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.2e-07 ninf=5.7e-07 < 3.5e-06 (0.162) 1 buf=    0   OK  
  pycuda C2C⨂         (36,36) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.9e-07 ninf=5.0e-07 < 3.6e-06 (0.142) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(39, 39), primes='3×13', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (40,40) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.1e-07 ninf=5.4e-07 < 3.6e-06 (0.149) 1 buf=    0   OK  
  pycuda C2C⨂         (42,42) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=4.6e-07 ninf=4.8e-07 < 3.6e-06 (0.133) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(44, 44), primes='2²×11', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (45,45) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.6e-07 ninf=5.3e-07 < 3.7e-06 (0.146) 1 buf=    0   OK  
  pycuda C2C⨂         (48,48) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=4.9e-07 ninf=4.9e-07 < 3.7e-06 (0.134) 1 buf=    0   OK  
  pycuda C2C⨂         (49,49) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.4e-07 ninf=6.1e-07 < 3.7e-06 (0.165) 1 buf=    0   OK  
  pycuda C2C⨂         (50,50) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.9e-07 ninf=7.9e-07 < 3.7e-06 (0.213) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(52, 52), primes='2²×13', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (54,54) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=8.1e-07 ninf=8.2e-07 < 3.7e-06 (0.221) 1 buf=    0   OK  
  pycuda C2C⨂         (55,55) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.2e-07 ninf=6.7e-07 < 3.7e-06 (0.180) 1 buf=    0   OK  
  pycuda C2C⨂         (56,56) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=4.8e-07 ninf=5.3e-07 < 3.7e-06 (0.142) 1 buf=    0   OK  
  pycuda C2C⨂         (60,60) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.0e-07 ninf=6.8e-07 < 3.8e-06 (0.180) 1 buf=    0   OK  
  pycuda C2C⨂         (63,63) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.6e-07 ninf=7.0e-07 < 3.8e-06 (0.184) 1 buf=    0   OK  
  pycuda C2C⨂         (64,64) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.0e-07 ninf=5.6e-07 < 3.8e-06 (0.147) 1 buf=    0   OK  
  pycuda C2C⨂         (65,65) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.5e-07 ninf=7.7e-07 < 3.8e-06 (0.202) 1 buf=    0   OK  
  pycuda C2C⨂         (66,66) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.1e-07 ninf=6.4e-07 < 3.8e-06 (0.167) 1 buf=    0   OK  
  pycuda C2C⨂         (70,70) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.8e-07 ninf=6.1e-07 < 3.8e-06 (0.159) 1 buf=    0   OK  
  pycuda C2C⨂         (72,72) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.0e-07 ninf=6.6e-07 < 3.9e-06 (0.170) 1 buf=    0   OK  
  pycuda C2C⨂         (75,75) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.8e-07 ninf=6.9e-07 < 3.9e-06 (0.178) 1 buf=    0   OK  
  pycuda C2C⨂         (77,77) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.6e-07 ninf=6.8e-07 < 3.9e-06 (0.174) 1 buf=    0   OK  
  pycuda C2C⨂         (78,78) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.9e-07 ninf=7.0e-07 < 3.9e-06 (0.180) 1 buf=    0   OK  
  pycuda C2C⨂         (80,80) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.9e-07 ninf=8.3e-07 < 3.9e-06 (0.212) 1 buf=    0   OK  
  pycuda C2C⨂         (81,81) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.7e-07 ninf=7.8e-07 < 3.9e-06 (0.198) 1 buf=    0   OK  
  pycuda C2C⨂         (84,84) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.4e-07 ninf=7.2e-07 < 3.9e-06 (0.184) 1 buf=    0   OK  
  pycuda C2C⨂         (88,88) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.6e-07 ninf=5.5e-07 < 3.9e-06 (0.139) 1 buf=    0   OK  
  pycuda C2C⨂         (90,90) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=9.0e-07 ninf=9.0e-07 < 4.0e-06 (0.229) 1 buf=    0   OK  
  pycuda C2C⨂         (91,91) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.1e-07 ninf=6.6e-07 < 4.0e-06 (0.167) 1 buf=    0   OK  
  pycuda C2C⨂         (96,96) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.9e-07 ninf=7.8e-07 < 4.0e-06 (0.195) 1 buf=    0   OK  
  pycuda C2C⨂         (98,98) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.7e-07 ninf=7.7e-07 < 4.0e-06 (0.193) 1 buf=    0   OK  
  pycuda C2C⨂         (99,99) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.9e-07 ninf=6.4e-07 < 4.0e-06 (0.161) 1 buf=    0   OK  
  pycuda C2C⨂       (100,100) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=8.1e-07 ninf=7.5e-07 < 4.0e-06 (0.186) 1 buf=    0   OK  
  pycuda C2C⨂       (104,104) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.7e-07 ninf=5.4e-07 < 4.0e-06 (0.134) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(105, 105), primes='3×5×7', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂       (108,108) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=8.3e-07 ninf=8.2e-07 < 4.0e-06 (0.203) 1 buf=    0   OK  
  pycuda C2C⨂       (110,110) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.3e-07 ninf=7.8e-07 < 4.0e-06 (0.193) 1 buf=    0   OK  
  pycuda C2C⨂       (112,112) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.0e-07 ninf=5.7e-07 < 4.0e-06 (0.140) 1 buf=    0   OK  
  pycuda C2C⨂       (117,117) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.2e-07 ninf=6.6e-07 < 4.1e-06 (0.162) 1 buf=    0   OK  
  pycuda C2C⨂       (120,120) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.1e-07 ninf=6.4e-07 < 4.1e-06 (0.157) 1 buf=    0   OK  
  pycuda C2C⨂       (121,121) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.4e-07 ninf=6.3e-07 < 4.1e-06 (0.155) 1 buf=    0   OK  
  pycuda C2C⨂       (125,125) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.1e-06 ninf=1.1e-06 < 4.1e-06 (0.262) 1 buf=    0   OK  
  pycuda C2C⨂       (126,126) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.3e-07 ninf=7.7e-07 < 4.1e-06 (0.188) 1 buf=    0   OK  
  pycuda C2C⨂       (128,128) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.8e-07 ninf=8.5e-07 < 4.1e-06 (0.207) 1 buf=    0   OK  
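The sizes in this sweep appear to be exactly the integers from 2 to 128 whose prime factors are all at most 13 (72 sizes, matching the 72 lines above). A small sketch, with hypothetical helper names, to regenerate that list and label sizes the same way as the `primes=` field of the log, e.g. to tabulate the failing ones:

```python
def factor_label(n):
    """Prime factorisation formatted like the log's `primes` field, e.g. 44 -> '2²×11'."""
    superscripts = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")
    parts, p = [], 2
    while n > 1:
        exponent = 0
        while n % p == 0:
            n //= p
            exponent += 1
        if exponent:
            parts.append(str(p) + (str(exponent).translate(superscripts) if exponent > 1 else ""))
        p += 1
    return "×".join(parts)

def is_smooth(n, primes=(2, 3, 5, 7, 11, 13)):
    """True if every prime factor of n (n >= 1) is in `primes`."""
    for p in primes:
        while n % p == 0:
            n //= p
    return n == 1

def radix_sizes(nmin=2, nmax=128):
    return [n for n in range(nmin, nmax + 1) if is_smooth(n)]

# Sizes reported as ERROR in the log above
failing = [21, 22, 26, 33, 39, 44, 52, 105]
for n in failing:
    print(n, factor_label(n))
```

The failures do not follow a single-prime pattern: for example 33 (3×11) fails while 66 (2×3×11) and 99 (3²×11) pass.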

vincefn, Mar 14 '24

Hello,

Sorry for the late reply; I am currently busy with another project. I will investigate the failing systems in the near future. Thank you for reporting them.

Best regards, Dmitrii

DTolm, Mar 22 '24