Speeding up the computation of sliding dot product with fft
[Update] Here are some important links to external resources and to comments made in this thread:
- So far: six-step / eight-step FFT-based MASS, with Numba (single-threaded)
- A comment noting that OTFFT uses the MIT License
- six-step / eight-step FFT
- OTFFT source code
- A suggestion about profiling for different lengths
- A few Q&As, including an interesting question: "Why is there a bump in performance for length 2^p, with p in range(17, 21)?" and "Can we implement a GPU version of it?"
- My initial post on Numba Discourse
- Q: What do the real / imaginary parts in FFT represent?
- Implementation of irfft
- Some comparisons regarding irfft
- Q: Can we use vectorization to improve performance?
- Using profila to profile Numba code
Currently, in STUMPY, the sliding dot product (of a query Q and a time series T) is computed via one of the two following functions:
- core.sliding_dot_product, which takes advantage of the FFT trick via scipy.signal.convolve
- core._sliding_dot_product, which uses njit on top of np.dot
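For orientation, the njit-based variant is conceptually just one dot product per sliding window. A minimal sketch of that idea (for illustration only, not STUMPY's exact code):

import numpy as np
from numba import njit

@njit
def _sliding_dot_product_naive(Q, T):
    # O(n * m): one np.dot per window position
    m = Q.shape[0]
    l = T.shape[0] - m + 1
    out = np.empty(l)
    for i in range(l):
        out[i] = np.dot(Q, T[i : i + m])
    return out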
The sliding dot product in MATLAB (via the FFT trick) seems to be faster though.
% MATLAB code
% x is the data, y is the query
m = length(y);
n = length(x);
y = y(end:-1:1);  % reverse the query
y(m+1:n) = 0;     % append zeros
% The main trick for getting all dot products in O(n log n) time
X = fft(x);
Y = fft(y);
Z = X.*Y;
z = ifft(Z);
% and then use the slice z(m:n)
Can we get closer to the performance of MATLAB?
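For reference, a direct NumPy port of the MATLAB snippet above might look like the following (a sketch for illustration; the function name is mine, not STUMPY's):

import numpy as np

def sliding_dot_product_fft(y, x):
    # y is the query, x is the data (mirroring the MATLAB snippet above)
    m = len(y)
    n = len(x)
    y_padded = np.zeros(n, dtype=np.float64)
    y_padded[:m] = y[::-1]  # reverse the query, then zero-pad to length n
    z = np.fft.ifft(np.fft.fft(x) * np.fft.fft(y_padded))
    return z.real[m - 1 : n]  # MATLAB's z(m:n) in 0-based indexing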
@NimaSarajpoor I just noticed that scipy.fft.rfft has a parameter called workers, which can perform FFT in parallel. I wonder if you could try that and see if it makes any difference?
@NimaSarajpoor I came across a more recent FFT implementation called OTFFT that claims to be faster than FFTW and has a more generous MIT license. However, I tried to implement the basic fft function in Python but haven't been able to get the same answer as scipy.fft.fft. Here's what I did (List-7: Final version of the Stockham Algorithm):
import math
import cmath
import numpy as np
def fft0(n, s, eo, x, y):
if not math.log2(n).is_integer(): # Check if n is power of 2
pass
m = n // 2
theta0 = 2 * math.pi / n
if n == 1:
if eo:
for q in range(s):
y[q] = x[q]
else:
for p in range(m):
wp = complex(math.cos(p*theta0), -math.sin(p*theta0))
for q in range(s):
a = complex(x[q + s*(p + 0)])
b = complex(x[q + s*(p + m)])
y[q + s*(2*p + 0)] = a + b
y[q + s*(2*p + 1)] = (a - b) * wp
fft0(n//2, 2*s, not eo, y, x)
def fft(n, x):
y = np.empty(n, dtype=complex)
fft0(n, 1, False, x, y)
for k in range(n):
x[k] /= n
Would you mind taking a look? Maybe I messed up somewhere but I've been staring at it for too long and I'm not able to spot anything. Thanks in advance!
@seanlaw
I came across a more recent FFT implementation called OTFFT that claims to be faster than FFTW and has a more generous MIT license
Cool!
Would you mind taking a look?
Sure! Will take a look.
Also:
I have been trying scipy.fft.rfft / scipy.fft.fft. Also, as you mentioned before, I am using different numbers of workers: 1 vs. os.cpu_count(). I haven't seen any improvement yet compared to stumpy.core.sliding_dot_product.
According to the scipy doc:
The workers argument specifies the maximum number of parallel jobs to split the FFT computation into. This will execute independent 1-D FFTs within x. So, x must be at least 2-D and the non-transformed axes must be large enough to split into chunks. If x is too small, fewer jobs may be used than requested.
I will test again and share the result and code for our future reference.
Currently, core.sliding_dot_product uses the scipy.signal.convolve function. In what follows, the performance of core.sliding_dot_product is compared with some alternatives.
import numpy as np
import scipy.fft
from stumpy import core

sliding_dot_product_v0 = core.sliding_dot_product
def sliding_dot_product_v1(Q, T):
n = len(T)
X = scipy.fft.rfft(T) * scipy.fft.rfft(np.flipud(Q), n=n)
out = scipy.fft.irfft(X, n=n)
return out[len(Q) - 1 :]
def sliding_dot_product_v2(Q, T):
n = len(T)
X = scipy.fft.rfft(T, workers=8) * scipy.fft.rfft(np.flipud(Q), n=n, workers=8)
out = scipy.fft.irfft(X, n=n, workers=8)
return out[len(Q) - 1 :]
def sliding_dot_product_v3(Q, T):
n = len(T)
X = np.fft.rfft(T) * np.fft.rfft(np.flipud(Q), n=n)
out = np.fft.irfft(X, n=n)
return out[len(Q) - 1 :]
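As a quick sanity check (a sketch; it assumes the functions above are in scope), all versions should agree on random inputs:

import numpy as np

Q = np.random.rand(50)
T = np.random.rand(10_000)

ref = sliding_dot_product_v0(Q, T)
for f in (sliding_dot_product_v1, sliding_dot_product_v2, sliding_dot_product_v3):
    np.testing.assert_allclose(f(Q, T), ref, atol=1e-8)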
And, this is the code for tracking the running time across different window sizes:
import time
import numpy as np
from scipy.io import loadmat

n = 1_000_000
data = np.array(loadmat('./DAMP_data/mit_long_term_ecg14046.mat')['mit_long_term_ecg_14046'][0]).astype(np.float64)
T = data[:n]

t = []
for m in range(3, 5000):
    Q = T[:m]
    t1 = time.time()
    comp = sliding_dot_product_function(Q, T)  # one of the versions above
    t2 = time.time()
    t.append(t2 - t1)
As observed:
- The functions sliding_dot_product_v1 and sliding_dot_product_v2 both use scipy's rfft and are identical except for the number of workers. As expected, their performances are close to each other, since the number of workers only affects performance for 2-D inputs (see the timing sketch below).
- The performance of sliding_dot_product_v3 (using numpy's rfft) is close to that of the existing version, sliding_dot_product_v0.
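To see the workers effect described in the scipy doc, one can time a stack of independent 1-D FFTs, which is the case that can be split across threads (an illustrative timing sketch):

import time
import numpy as np
import scipy.fft

# Many independent 1-D FFTs stacked as rows: this is what `workers` can split
x2d = np.random.rand(64, 2 ** 20)

for workers in (1, 8):
    tic = time.time()
    scipy.fft.rfft(x2d, axis=-1, workers=workers)
    print(f"workers={workers}: {time.time() - tic:.3f}s")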
Thanks @NimaSarajpoor. In case it matters (and if you're not already doing this), it would make sense to test window sizes and/or time series lengths in powers of 2 rather than increments of 1.
[Note]
According to the source code of scipy.fft.rfft, we can use the parameter n to pad the input Q with zeros to make its length the same as the length of T. So, if the length of T is a power of two, I think the length of the query does not need to be a power of two. I am going to provide the performance of the sliding dot product functions for queries with length in range(4, 1025):
[Update] Correction: the x-axis label of the bottom figure should read "the length of query".
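To illustrate the padding behavior of the n parameter mentioned above (a quick check, for reference):

import numpy as np
import scipy.fft

n = 2 ** 14                # power-of-two FFT length
Q = np.random.rand(100)    # the query length need not be a power of two

padded = np.zeros(n)
padded[:len(Q)] = Q
# rfft(Q, n=n) zero-pads Q to length n internally
np.testing.assert_allclose(scipy.fft.rfft(Q, n=n), scipy.fft.rfft(padded))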
@seanlaw
@NimaSarajpoor I came across a more recent FFT implementation called OTFFT that claims to be faster than FFTW and has a more generous MIT license. [...] Would you mind taking a look? Maybe I messed up somewhere but I've been staring at it for too long and I'm not able to spot anything.
It turns out that x holds the correct output if we just avoid dividing it by n.
def fft0(n, s, eo, x, y):
if not math.log2(n).is_integer(): # Check if n is power of 2
pass
m = n // 2
theta0 = 2 * math.pi / n
if n == 1:
if eo:
for q in range(s):
y[q] = x[q]
else:
for p in range(m):
wp = complex(math.cos(p*theta0), -math.sin(p*theta0))
for q in range(s):
a = complex(x[q + s*(p + 0)])
b = complex(x[q + s*(p + m)])
y[q + s*(2*p + 0)] = a + b
y[q + s*(2*p + 1)] = (a - b) * wp
fft0(n//2, 2*s, not eo, y, x)
# I swapped the params of the `fft` function to make its signature similar to `scipy.fft.fft`
def fft(x, n):
y = np.empty(n, dtype=complex)
fft0(n, 1, False, x, y)
return x
And, to test it:
for power in range(1, 10):
n = 2 ** power
x = np.random.rand(n).astype(complex)
ref = scipy.fft.fft(x)
np.testing.assert_almost_equal(ref, fft(x, n))
It turns out that x holds the correct output if we just avoid dividing it by n.
Hmmm, I wonder why they performed the division?! Thanks for figuring it out. I just ported it over blindly without trying to understand it 🤣.
How about the ifft?
def ifft(n, x):
for p in range(n):
x[p] = x[p].conjugate()
y = np.empty(n, dtype=complex)
fft0(n, 1, False, x, y)
# for k in range(n):
# x[k] = x[k].conjugate()
This doesn't seem to match scipy.fft.ifft either.
Would you mind doing a performance comparison if you are able to crack this?
@seanlaw
How about the ifft? This doesn't seem to match scipy.fft.ifft either. Would you mind doing a performance comparison if you are able to crack this?
This should work:
def _ifft(x):
    n = len(x)  # assuming `n` is a power of two
    x[:] = np.conjugate(x)       # conjugate in place
    y = np.empty(n, dtype=np.complex128)
    fft0(n, 1, False, x, y)      # forward FFT of the conjugated input (result lands in x)
    return np.conjugate(x / n)   # conjugate back and scale by 1/n
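And a quick check against scipy (assuming fft0 from above is in scope):

import numpy as np
import scipy.fft

for power in range(1, 10):
    n = 2 ** power
    x = np.random.rand(n).astype(complex)
    ref = scipy.fft.ifft(x.copy())
    np.testing.assert_almost_equal(ref, _ifft(x))  # note: _ifft modifies x in place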
I am working on some minor changes to boost the performance. I will share the performance of both the original version and the enhanced version, and will compare them with core.sliding_dot_product.
For now, I did some enhancements on the new fft / ifft functions suggested in https://github.com/TDAmeritrade/stumpy/issues/938#issuecomment-1865067417.
Part (I):
I show the performance of several versions against the performance of our reference version, i.e. core.sliding_dot_product. The output of each version is tested to make sure that the function works correctly. The length of the time series T is $2^{15}$, and the length of the query is in range(10, 1000 + 10, 10).
Here is a description of the versions (a combined sketch follows the list):
v0 --> use the new functions fft and ifft
v1 --> v0 + Reused the already-allocated memory `y`
v2 --> v1 + Converted the inner for-loop of `fft` function to a numpy vectorized operation
v3 --> v2 + Added njit decorator with `fastmath=True`
v4 --> v3 + Parallelized the outer for-loop of `fft` function
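For reference, a minimal sketch of what the v1-v4 changes might look like when combined (hypothetical names; the actual notebook code may differ). The recursion of fft0 is unrolled into a loop, the inner q-loop becomes slice arithmetic, and the outer p-loop is parallelized:

import numpy as np
from numba import njit, prange

@njit(fastmath=True, parallel=True)
def _fft0_v4(n, s, eo, x, y):
    # Iterative Stockham stages; the recursion of `fft0` is unrolled into a while-loop
    while n > 1:
        m = n // 2
        theta0 = 2.0 * np.pi / n
        for p in prange(m):  # v4: parallelized outer loop
            wp = np.cos(p * theta0) - 1j * np.sin(p * theta0)
            a = x[s * p : s * p + s]  # v2: inner q-loop vectorized as slices
            b = x[s * (p + m) : s * (p + m) + s]
            y[2 * s * p : 2 * s * p + s] = a + b
            y[2 * s * p + s : 2 * s * (p + 1)] = (a - b) * wp
        x, y = y, x  # v1: reuse the two already-allocated buffers
        n, s, eo = m, 2 * s, not eo
    if eo:  # base case (n == 1): move the result into the expected buffer
        y[:s] = x[:s]

def fft_v4(x):
    n = x.size  # assumed to be a power of two
    x = np.asarray(x, dtype=np.complex128).copy()
    y = np.empty(n, dtype=np.complex128)
    _fft0_v4(n, 1, False, x, y)
    return x  # the result ends up in this buffer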
For a time series with length $2^{15}$, it seems that there is not much difference between v0 and v1. The changes in v2 and v3 seem to be very effective. How about v4? To better demonstrate its impact, the figure below shows the performance of v3, v4, and the reference only.
And we can zoom in further by removing v3 from the figure. Then we will see:
Part (II):
We now show how v4 performs against the ref (i.e. core.sliding_dot_product) for different lengths of the time series T.
As observed, the gap in the performance becomes bigger as the length of time series increases.
The code is available in the notebook pushed to this PR #939.
Next steps: (1) Make the code cleaner. (2) Profile the function to see where that increase in the performance gap comes from. (3) Optimize accordingly.
@seanlaw
What do you think?
Also: If I need to add/revise a step, please let me know.
If I understand correctly, the Stockham algorithm is NOT faster than scipy.signal.convolve (or they are about the same after some optimizations). Is that correct? And it also means that the Stockham algo is much slower than FFTW?
I wonder if there might be some clues in the OTFFT source code in terms of how they might have parallelized it using OpenMP. I'd be pretty happy if we could get within 2x (slower) of FFTW. I'm also guessing that the sawtooth shape observed in scipy.signal.convolve likely comes from switching between two FFT algorithms. It's reassuring to see that OTFFT is pretty stable in performance across different input lengths. I'm confused as to why using prange wouldn't give you the necessary speedup but that also depends on the hardware that you are using. Maybe I can find some time to test it out on my Mac with many threads.
If I understand correctly, the Stockham algorithm is NOT faster than scipy.signal.convolve (or they are about the same after some optimizations). Is that correct? And it also means that the Stockham algo is much slower than FFTW?
Yes. After some initial optimizations, I can see that the Stockham algorithm (from OTFFT) is slower. So, as you mentioned:
- The Python Stockham algo (from OTFFT) is slower than scipy.signal.convolve.
- scipy.signal.convolve is slower than MATLAB FFTW.
I wonder if there might be some clues in the OTFFT source code in terms of how they might have parallelized it using OpenMP. I'd be pretty happy if we could get within 2x (slower) of FFTW.
I will try to go through it to get some ideas. We need to run the tests on the MATLAB online server if we want to use MATLAB FFTW as our benchmark.
I'm confused as to why using prange wouldn't give you the necessary speedup but that also depends on the hardware that you are using.
I think it gave us some boost. See the figure below...
v4 is the same as v3 except that it uses prange.
numba.get_num_threads() is 8 on my macOS system.
It appears that maybe we should consider implementing the six-step or eight-step FFT algorithm next as it should have much better memory locality and is therefore "optimized". I'd expect this to be faster than our current sliding dot product. I'm not sure how/if any of the radix variants (shown at the same link above) will help.
we should consider implementing the six-step or eight-step FFT algorithm next as it should have much better memory locality and is therefore "optimized"
I have implemented the six-step FFT. I will work on the eight-step FFT algorithm next.
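For context, here is a minimal NumPy sketch of the six-step idea for n = 2^(2k): view the input as an n1 x n2 matrix, and factor the transform into row FFTs, twiddle multiplications, and transposes (the transposes are what buy the cache locality). This is illustrative only; np.fft.fft stands in for the recursive sub-FFTs of the real implementation:

import numpy as np

def six_step_fft(x):
    n = x.size
    n1 = n2 = int(np.sqrt(n))
    assert n1 * n2 == n, "n must be an even power of two"
    a = np.asarray(x, dtype=np.complex128).reshape(n2, n1)
    a = a.T.copy()                             # step 1: transpose (contiguous rows)
    a = np.fft.fft(a, axis=1)                  # step 2: n1 FFTs of length n2
    j1 = np.arange(n1).reshape(-1, 1)
    k2 = np.arange(n2).reshape(1, -1)
    a = a * np.exp(-2j * np.pi * j1 * k2 / n)  # step 3: twiddle factors
    a = a.T.copy()                             # step 4: transpose
    a = np.fft.fft(a, axis=1)                  # step 5: n2 FFTs of length n1
    return a.T.copy().reshape(n)               # step 6: transpose to natural order

x = np.random.rand(2 ** 10)
np.testing.assert_allclose(six_step_fft(x), np.fft.fft(x), atol=1e-9)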
[Note to self] Before I forget, here are a couple of notes that I ~may consider~ need to revisit later:
- numpy vectorized operations seem to increase the running time (?!)
- Turning off parallelization may speed up the computation (?!)
I have implemented the six-step FFT. I will work on the eight-step FFT algorithm next.
In case it matters, section 6.3 might be relevant as it discusses Stockham and the 6-step algo. More importantly, it describes how/why cache memory is important.
According to this Mathworks webpage, MATLAB is equipped with built-in multithreading.
The built-in multithreading feature in MATLAB automatically parallelizes computations using underlying libraries at runtime, resulting in significant performance improvements for linear algebra and numerical functions such as fft, eig, mldivide, svd, and sort. Since MATLAB does this implicitly, you do not need to enable it manually. If you would like to control the number of computational threads, you can use the maxNumCompThreads function.
A few notes:
(1) As suggested, we can use maxNumCompThreads to set the number of threads to 1, and then compare the MATLAB code (for the sliding dot product) with stumpy.core.sliding_dot_product. We use the MATLAB online server.
(2) We can implement the 6-step / 8-step FFT, parallelize it, and then compare it with the MATLAB code. We use the MATLAB online server.
(3) Should we expect the 6-step and 8-step implementations (without multithreading) to be faster than scipy.fft.fft? The reason I am asking is that scipy.fft.fft is written in C++, and I was wondering whether it makes sense to expect it to be outperformed by a 6-step / 8-step FFT written in Python. If not, then we can just stick to (2) and see how it performs compared to the MATLAB code.
Should we expect the 6-step and 8-step implementations (without multithreading) to be faster than scipy.fft.fft?
Based on these performance results I expect the njit version of the 6/8 step algo to be faster than scipy.fft.fft and, possibly, as fast as FFTW.
[Update] I implemented the 8-step algo as well. I ran it on MATLAB online server, and got this:
Based on these performance results I expect the njit version of the 6/8 step algo to be faster than scipy.fft.fft and, possibly, as fast as FFTW.
I will go through section 6.3 and check the algo again to better understand it, and see if I am doing something wrong in my code. I may start with a notebook and implement each function in one cell so that we can go through them together, @seanlaw.
Also, need to take a look at the OTFFT source code
I implemented the 8-step algo as well. I ran it on MATLAB online server, and got this:
This is 8-step with njit and parallel=True? Any idea if the 8-step is faster than scipy?
This is 8-step with njit and parallel=True?
Yes, and yes
Any idea if the 8-step is faster than scipy?
I think scipy.fft is slightly faster than my current implementation of 6-step / 8-step algo, in my Mac.
Note that the 8-step is for $2^{odd}$ and the 6-step is for $2^{even}$. (Btw, the former calls the latter twice.) Instead of the sliding dot product, I am going to show the performance of the fft function itself. I tested the performance for a time series T whose length is a power of two, from $2^{1}$ to $2^{24}$.
See the figure below. ref is scipy.fft.fft, and comp is our fft function (with njit and parallel=True), which calls the 6-step algo when the power is even and the 8-step algo when the power is odd.
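As an aside, the "the former calls the latter twice" structure can be sketched directly: for n = 2^(odd), one decimation-in-frequency stage yields two half-length (2^(even)) sub-transforms, each of which the six-step algorithm can handle (reusing the six_step_fft sketch above; illustrative only):

import numpy as np

def eight_step_fft(x):
    n = x.size
    half = n // 2
    a = np.asarray(x, dtype=np.complex128)[:half]
    b = np.asarray(x, dtype=np.complex128)[half:]
    w = np.exp(-2j * np.pi * np.arange(half) / n)  # twiddle factors
    X = np.empty(n, dtype=np.complex128)
    X[0::2] = six_step_fft(a + b)        # even-indexed outputs
    X[1::2] = six_step_fft((a - b) * w)  # odd-indexed outputs
    return X

x = np.random.rand(2 ** 11)  # a 2^(odd) length
np.testing.assert_allclose(eight_step_fft(x), np.fft.fft(x), atol=1e-9)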
[Update-1] After reducing the number of redundant arithmetic operations... I got this:
[Update-2]
I made some more minor enhancements. Also, I am now checking the performance of the fft of T with length up to $2^{27}$. Furthermore, I now run each case 21 times and take the median as the running time.
It's surprising that it jumps so dramatically at 2^24 to 2^27. I wonder if FFTW suffers from this too or if it still scales linearly. Technically, this algo should be n*logn in computational complexity.
It's surprising that it jumps so dramatically at 2^24 to 2^27. I wonder if FFTW suffers from this too or if it still scales linearly. Technically, this algo should be n*logn in computational complexity.
I do not know what is happening there. I wanted to try 2^28 but I noticed it takes too much time, so I just stopped at 2^27. I will check the behaviour of MATLAB FFTW for large-size inputs.
As suggested by @seanlaw in https://github.com/TDAmeritrade/stumpy/pull/939#discussion_r1438665912, we are going to take advantage of rfft since the input is real-valued. Based on the algo provided in section 2.6.2 and summarized in https://github.com/TDAmeritrade/stumpy/pull/939#discussion_r1438801994, I implemented an initial version of fft that takes advantage of rfft.
In the following figure, ref is scipy.fft.fft, and comp is an njit function (with parallel=True) that takes advantage of rfft.
I will work on finding opportunities to improve the performance.
According to https://github.com/TDAmeritrade/stumpy/issues/938#issuecomment-1868453389, we might be able to just use rfft for computing the sliding dot product (see the function below). So, we do not need to construct the full fft; we just need to implement the equivalent of np.fft.irfft.
def sliding_dot_product_v3(Q, T):
n = len(T)
X = np.fft.rfft(T) * np.fft.rfft(np.flipud(Q), n=n)
out = np.fft.irfft(X, n=n)
return out[len(Q) - 1 :]
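For reference, the reason an irfft equivalent is all we need: for real input the spectrum is conjugate-symmetric, so the full fft (and hence the inverse) can be reconstructed from the rfft half-spectrum. A quick NumPy check of the identity (illustrative only):

import numpy as np

x = np.random.rand(16)
Xr = np.fft.rfft(x)  # n//2 + 1 terms

# Conjugate symmetry: X[n - k] = conj(X[k]) recovers the full spectrum
X_full = np.concatenate([Xr, np.conjugate(Xr[-2:0:-1])])
np.testing.assert_allclose(X_full, np.fft.fft(x))

# Hence irfft(Xr) equals the real part of a full complex ifft
np.testing.assert_allclose(np.fft.irfft(Xr, n=16), np.fft.ifft(X_full).real, atol=1e-12)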
Will check the behaviour of MATLAB FFTW for large-size inputs.
It seems that MATLAB FFTW shows similar behaviour! (see below)
Also, the following figures confirm that six-step / eight-step FFT outperform FFTW in MATLAB
Let's zoom in for log2(len(T)) in range(2, 16):
How about p in range(2, 21)?
MATLAB
% MATLAB
fft_running_time = [];
pmax = 27;
rng(1,"twister");
T = rand(1, 2^pmax);
save('T_input.mat', 'T'); % we will use this for the performance of our FFT
p_list = 2:1:pmax;
for p = p_list
    idx = 2^p;
    data = T(1, 1:idx);
    t = tic;
    X = fft(data);
    fft_running_time(end+1) = toc(t);
end
Python
# Python
import time
import numpy as np
from scipy.io import loadmat

fft(np.random.rand(4), is_rfft=False, y=None)  # dummy call to trigger JIT compilation
fft_running_time_python = []
T = np.array(loadmat('./T_input.mat')['T'][0]).astype(np.float64)
p_list = range(2, 28)
for p in p_list:
    idx = 2 ** p
    data = T[:idx]
    tic = time.time()
    X = fft(data, is_rfft=False, y=None)
    toc = time.time()
    fft_running_time_python.append(toc - tic)
The fft function used here takes advantage of rfft but returns the full-length fft at the end. The fft Python code used for this exact performance comparison is provided here:
https://gist.github.com/NimaSarajpoor/5cf2d4f0b0aad5e0898651dddff35c17
@seanlaw I created and pushed a new notebook to PR #939. I am now thinking of adding functions from the link above to it, one by one. What do you think?
Maybe we should just ignore the recent pushes to the new notebook, and I will start over by adding one function at a time.
In case it matters, I also checked the performance of scipy.fft.fft using the MATLAB online server for the same input data used in my previous comment.
Also, the following figures confirm that six-step / eight-step FFT outperform FFTW in MATLAB
Amazing! Can you confirm whether MATLAB and 6/8-step Python are single threaded or multithreaded?
I created and pushed a new notebook to PR https://github.com/TDAmeritrade/stumpy/pull/939 , I am now thinking of adding functions from the link above to it, one by one. What do you think?
Please give me some time to review and respond
Can you confirm whether MATLAB and 6/8-step Python are single threaded or multithreaded?
Before I provide some answers, I should mention that I got slightly different results!
Short answer: They were multithreaded.
Long answer: They were multithreaded. I think it is important to set the number of threads. (I am not sure whether I made a mistake before or the online server behaved differently, but I think we can trust the following results this time, as I explicitly set the number of threads in both the MATLAB code and the Python code.)
To make sure I am not making any accidental mistakes, I measured the performance for three cases: N_THREADS in {1, 4, 8}.
To manage threads:

% in MATLAB
N = 8;  % number of threads I would like to use
LASTN = maxNumCompThreads(N);

# in Python
import numba
N = 8  # number of threads I would like to use
numba.set_num_threads(N)
I first show the performance of FFT in MATLAB for the three cases:
And, now I show the performance of FFT in Python for the same three cases:
So, we can see that the number of threads affects the performance.
Okay, let's compare MATLAB vs Python. To show this, I subtract the running time of the Python code from the corresponding running time of the MATLAB code. So a positive y-value means that the Python code is faster than its MATLAB counterpart.
It seems that as the number of threads increases, the six-step / eight-step FFT shows better performance!
Btw, let's zoom in for range(5, 21)
And MATLAB-vs-Python when we have 8 threads:
Okay... I checked the performance again. I ran each case 50 times and then took the average. The following cases are considered:
- MATLAB fft
- Python six-step / eight-step fft
- Python six-step / eight-step rfft
- Scipy fft
- Scipy rfft
And, for each of the cases above, I considered three scenarios: 1 thread / 4 threads / 8 threads.
The results are provided below. I decided not to zoom in on each figure so that this comment stays easy to follow. In the following figures, 6fft refers to the 6-step / 8-step fft.
MATLAB
Python vs MATLAB [with 1 thread]
When we have one thread only, scipy fft / rfft wins.
Python vs MATLAB [with 8 threads]
Note that scipy is not parallelized. But, just to be consistent, I added it to the figure for the sake of comparison.
Showing the diff: Python vs MATLAB [with 8 threads]
In the following figure, the red plot shows the MATLAB running time minus the fft running time, and the blue plot shows the MATLAB running time minus the 6fft running time.
[Update]
And if I zoom into the figure above to better see the comparison for time series with length <= 2^20, I get this:
which shows that MATLAB FFT performs better when the length is <= 2^20.
Thanks @NimaSarajpoor. Can you please provide the raw timing data in a table? I can see that we are doing better at longer time series lengths and with 8 threads. This is great. However, I think the majority of use cases will likely be in the 2^20 range (what do you think?) and so I'd like to see how close we are to MATLAB-FFTW. If we are several magnitudes off (in the <= 2^20 range) in timing then we still have work to do. If we are within 2x then I am okay with that.
However, I think the majority of use cases will likely be in the 2^20 range (what do you think?)
Right! So, it is important to have good performance in that range.
Can you please provide the raw timing data in a table?
TL;DR: It does not look good 😄
Sure. The table is provided below. The numbers are rounded to seven decimal places. I also added columns to show the ratios. If it is too much, you may just take a look at the figure provided at the bottom of this comment. (The figure shows the 6fft / MATLAB ratio as well as the 6rfft / MATLAB ratio.)
| Log2 of len(T) | Python fft / MATLAB (ratio) | Python rfft / MATLAB (ratio) | MATLAB (s) | Python fft (s) | Python rfft (s) |
|---|---|---|---|---|---|
| 2 | 0.0327306 | 0.0265907 | 0.0004038 | 1.32E-05 | 1.07E-05 |
| 3 | 13.3032982 | 13.4683572 | 2.1E-06 | 2.77E-05 | 2.8E-05 |
| 4 | 31.7814838 | 31.6292407 | 1.7E-06 | 5.28E-05 | 5.25E-05 |
| 5 | 8.3517041 | 8.3672224 | 3.4E-06 | 2.82E-05 | 2.83E-05 |
| 6 | 5.0265435 | 5.2300492 | 1.04E-05 | 5.22E-05 | 5.43E-05 |
| 7 | 5.1188469 | 5.2070618 | 6E-06 | 3.07E-05 | 3.12E-05 |
| 8 | 4.7836901 | 4.7770128 | 1.21E-05 | 5.81E-05 | 5.8E-05 |
| 9 | 2.555058 | 2.4614294 | 1.43E-05 | 3.64E-05 | 3.51E-05 |
| 10 | 2.4825403 | 2.6721459 | 2.62E-05 | 6.5E-05 | 7E-05 |
| 11 | 1.49305 | 1.3197951 | 3.09E-05 | 4.61E-05 | 4.08E-05 |
| 12 | 2.3069466 | 2.1782147 | 4.06E-05 | 9.36E-05 | 8.83E-05 |
| 13 | 1.2975097 | 1.3393951 | 6.1E-05 | 7.92E-05 | 8.17E-05 |
| 14 | 1.8206379 | 1.7744369 | 9.97E-05 | 0.0001815 | 0.0001769 |
| 15 | 1.3474092 | 1.2914043 | 0.0001739 | 0.0002343 | 0.0002245 |
| 16 | 1.4363382 | 1.4223928 | 0.0003683 | 0.0005289 | 0.0005238 |
| 17 | 1.4568064 | 1.4602421 | 0.0006648 | 0.0009685 | 0.0009708 |
| 18 | 2.3318331 | 2.2741222 | 0.000938 | 0.0021872 | 0.002133 |
| 19 | 1.9355066 | 1.8724722 | 0.001927 | 0.0037296 | 0.0036082 |
| 20 | 1.7575101 | 2.1766148 | 0.0038541 | 0.0067736 | 0.0083889 |
| 21 | 0.3776821 | 0.5315373 | 0.0301298 | 0.0113795 | 0.0160151 |
| 22 | 0.5337859 | 0.6022327 | 0.0620284 | 0.0331099 | 0.0373555 |
| 23 | 0.4275656 | 0.5058924 | 0.160916 | 0.0688022 | 0.0814062 |
| 24 | 0.9009623 | 0.8916813 | 0.2904172 | 0.261655 | 0.2589596 |
| 25 | 0.7019051 | 0.5632563 | 0.4811759 | 0.3377398 | 0.2710254 |
| 26 | 1.0531359 | 0.9573159 | 1.1281815 | 1.1881284 | 1.0800261 |
| 27 | 0.6763958 | 0.6486571 | 2.3729286 | 1.6050389 | 1.5392169 |
Maybe I am not taking advantage of parallelization in the proper way. I am thinking of profiling the blocks of code in this function to see if that gives me a clue. Do you think that is a good approach?
[Update] I think I need to dig into section 7 and then check the source code. I have not checked it yet.
If it is too much, you may just take a look at the figure provided at the bottom of this comment
This is great! Thank you
TL;DR: It does not look good
Hmm, maybe I am misunderstanding the data but I'm not sure I agree. For the most part, it looks like we are within around 2x of MATLAB for lengths longer than 2^9. This is great and I do not expect to beat MATLAB/FFTW without a ton of fine tuning. Sure, for lengths that are shorter than 2^9, MATLAB is running at 10^-6 and Python is running at 10^-5 but this is basically instantaneous. Yes, the time can add up if we need to do millions+/billions+ of FFT calculations but I am very optimistic as we are actually much, much better than I had expected (thanks to your work)!
I have a question. Why is Python rfft not faster than Python fft?
I think I need to dig into section 7 and then check the source code. I have not checked it yet.
For lengths shorter than 2^20, my suspicion is that the majority of our Python time is spent in non-FFT overhead (e.g., creating an array). I will take a look