DESC

Poloidal FFT Implementation

Open dpanici opened this issue 10 months ago • 4 comments

The initial implementation works for the Fourier, DoubleFourier, and FourierChebyshev bases, but not yet for Zernike.

  • [ ] Add poloidal derivatives
  • [ ] Figure out Zernike. I think we need to transform the radial part first? It's tricky because the radial index couples L and M
  • [ ] Add support for projection / fit / inverse transform using the poloidal FFT
  • [ ] Add more tests
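For the poloidal-derivative item, one way an FFT path could provide derivatives is spectral differentiation: multiply each poloidal Fourier coefficient by i·m and transform back. Below is a minimal JAX sketch that assumes a real, periodic function sampled on an equispaced poloidal grid along the last axis. `poloidal_derivative` is a hypothetical helper, not an existing DESC function.

```python
import jax.numpy as jnp


def poloidal_derivative(f):
    """Hypothetical helper (not a DESC API): d/dtheta of real periodic samples
    on theta_j = 2*pi*j/n, taken along the last axis."""
    n = f.shape[-1]
    # integer poloidal mode numbers in rfft ordering: 0, 1, ..., n//2
    m = jnp.fft.rfftfreq(n, d=1.0 / n)
    if n % 2 == 0:
        # the Nyquist mode has no well-defined first derivative; drop it
        m = m.at[-1].set(0.0)
    f_hat = jnp.fft.rfft(f, axis=-1)
    return jnp.fft.irfft(1j * m * f_hat, n=n, axis=-1)


theta = jnp.linspace(0, 2 * jnp.pi, 32, endpoint=False)
df = poloidal_derivative(jnp.sin(3 * theta))
# df should match 3*cos(3*theta) to float32 precision
```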

Resolves #641

dpanici • Jan 06 '25

We may need to change the code to move the grid checks after #1405.

dpanici • Jan 08 '25

(If you're hooking in FFTs for the Chebyshev-double-Fourier basis, the real-to-spectral transform should be ordered as coeffs = rfft2(dct(f)): dct returns a real output, so it can feed directly into rfft2, which expects real input.)
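A minimal NumPy/SciPy sketch of that ordering, assuming samples `f` stored with shape (radial, poloidal, toroidal) and a radial axis sampled at Chebyshev points, so that the DCT plays the role of the Chebyshev transform (up to normalization). The DCT runs first along the radial axis because its output is still real, which is what `rfft2` expects; the inverse applies the steps in reverse. The function names are hypothetical.

```python
import numpy as np
from scipy.fft import dct, idct


def to_spectral(f):
    """Hypothetical helper: real samples (n_rho, n_theta, n_zeta) -> complex coefficients."""
    # DCT along the radial axis first: its output is real, so rfft2 can follow
    g = dct(f, type=2, axis=0, norm="ortho")
    # real FFT over the two flux-surface (poloidal, toroidal) axes
    return np.fft.rfft2(g, axes=(1, 2))


def from_spectral(c, n_theta, n_zeta):
    """Hypothetical inverse: irfft2 over the flux-surface axes, then inverse DCT radially."""
    g = np.fft.irfft2(c, s=(n_theta, n_zeta), axes=(1, 2))
    return idct(g, type=2, axis=0, norm="ortho")


rng = np.random.default_rng(0)
f = rng.standard_normal((8, 10, 12))
c = to_spectral(f)
print(c.shape)  # (8, 10, 7): rfft2 keeps n_zeta // 2 + 1 toroidal modes
f_back = from_spectral(c, 10, 12)
```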

unalmis • Apr 02 '25

  • Add a util for converting to the NumPy FFT basis
  • Make the Ptolemy identity work with JAX; maybe also a matrix-free method?

dpanici • Apr 16 '25

The current Ptolemy identity code uses some dynamic sizing and np.where operations that won't work with JAX. I think one possible fix would be to pass in a Basis instead of just the modes array. Then, e.g., Basis.N is statically known, so we can always do:

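```python
import jax.numpy as jnp

# all mode numbers -N..N; the size is static because basis.N is known
modes_full = jnp.arange(-basis.N, basis.N + 1)
x_full = jnp.zeros(2 * basis.N + 1)
# offset by N so mode n lands at index n + N (negative n would otherwise wrap around)
x_full = x_full.at[basis.modes[:, 2] + basis.N].set(x)
```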

Then we have a statically sized array with all of the mode numbers and coefficients, which can be manipulated more easily without conditionals.
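Following that idea, here is a minimal sketch of a JAX-compatible conversion to the NumPy FFT ordering for a 1D series. It assumes the convention that a coefficient with n >= 0 multiplies cos(nζ) and one with n < 0 multiplies sin(|n|ζ); that convention is an assumption here and should be checked against the actual basis definition. `eval_fourier_fft` is a hypothetical helper. Because `N` and `L` are static, every array has a fixed shape and no `np.where` is needed.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnames=["N", "L"])
def eval_fourier_fft(n_modes, x, N, L):
    """Hypothetical helper: evaluate sum_n x_n * (cos(n z) if n >= 0 else sin(|n| z))
    at z_j = 2*pi*j/L (the cos/sin sign convention is an assumption)."""
    assert L >= 2 * N + 1, "grid too coarse for the modes"  # N, L are static ints
    # statically sized array indexed by n + N (the offset avoids negative indices)
    x_full = jnp.zeros(2 * N + 1).at[n_modes + N].set(x)
    a = x_full[N + 1 :]   # cosine coefficients for n = 1..N
    b = x_full[:N][::-1]  # sine coefficients for n = 1..N (stored at n = -1..-N)
    # complex exponential coefficients in NumPy FFT order: k = 0, 1..N, ..., -N..-1
    c = jnp.zeros(L, dtype=jnp.complex64)
    c = c.at[0].set(x_full[N])
    c = c.at[1 : N + 1].set((a - 1j * b) / 2)
    c = c.at[L - N :].set(((a + 1j * b) / 2)[::-1])
    return (L * jnp.fft.ifft(c)).real


n_modes = np.arange(-3, 4)  # n = -3..3
x = np.random.default_rng(0).standard_normal(7)
f = eval_fourier_fft(n_modes, x, N=3, L=16)
```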

f0uriest • Apr 18 '25

To get this to work properly, I think we decided the best option would be a way to convert our Fourier representation to the VMEC one, since that can be used directly with 2D FFTs. I played around with a few different ways to do this here: https://gist.github.com/f0uriest/b12be9d5f22b845463158793fefba82b
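For reference, the identity behind the conversion is a·cos(mθ)cos(nζ) + b·sin(mθ)sin(nζ) = ½(a+b)·cos(mθ - nζ) + ½(a-b)·cos(mθ + nζ). Each product-basis pair therefore maps onto VMEC-style cos(mθ - n'ζ) terms at n' = n and n' = -n. Below is a minimal NumPy sketch, not the gist's implementation. It assumes NFP = 1 and only the cos·cos and sin·sin terms (the up/down symmetric case). The names `product_to_vmec` and `eval_vmec_fft` are hypothetical. The sketch then evaluates the VMEC form with a single 2D inverse FFT.

```python
import numpy as np


def product_to_vmec(a, b):
    """Hypothetical helper. a[m, n]: cos(m t)cos(n z), b[m, n]: sin(m t)sin(n z),
    for m = 0..M, n = 0..N. Returns c[m, n' + N] multiplying cos(m t - n' z)."""
    M1, N1 = a.shape
    N = N1 - 1
    c = np.zeros((M1, 2 * N + 1))
    c[:, N:] += (a + b) / 2     # cos(m t - n z) -> n' = +n
    c[:, N::-1] += (a - b) / 2  # cos(m t + n z) -> n' = -n
    return c


def eval_vmec_fft(c, P, Q):
    """Hypothetical helper: sum c[m, n'] cos(m t_j - n' z_k) on t_j = 2 pi j/P, z_k = 2 pi k/Q."""
    M1, N2 = c.shape
    N = (N2 - 1) // 2
    m = np.broadcast_to(np.arange(M1)[:, None], c.shape)
    n = np.broadcast_to(np.arange(-N, N + 1)[None, :], c.shape)
    D = np.zeros((P, Q), dtype=complex)
    # exp(i(m t_j - n' z_k)) only depends on m mod P and -n' mod Q, so accumulate there
    np.add.at(D, (m % P, (-n) % Q), c)
    return (P * Q * np.fft.ifft2(D)).real


a = np.array([[1.0, 0.0], [0.0, 2.0]])  # 1 + 2 cos(t)cos(z)
b = np.array([[0.0, 0.0], [0.0, 3.0]])  # + 3 sin(t)sin(z)
c = product_to_vmec(a, b)  # represents 1 + 2.5 cos(t - z) - 0.5 cos(t + z)
```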

Here are some timings for a DoubleFourierBasis with M=N=20:

```python
import jax
import desc.vmec_utils

# basis: a DoubleFourierBasis with M=N=20; x: its coefficients
# ptolemy_identity_rev_jax and ptolemy_identity_rev_jax_vmap are defined in the gist linked above

# existing method, np only, for comparison
%timeit _, _, s2, c2 = desc.vmec_utils.ptolemy_identity_rev(basis.modes[:,1], basis.modes[:,2], x[None])
# 297 ms ± 29.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit _, _, s3, c3 = jax.block_until_ready(ptolemy_identity_rev_jax(basis, x))
# 30.2 ms ± 2.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit _, _, s3, c3 = jax.block_until_ready(ptolemy_identity_rev_jax_vmap(basis, x))
# 9.38 ms ± 1.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

However, an even better option seems to be to just precompute the linear transform matrix, e.g.:

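```python
import jax
import jax.numpy as jnp
import desc.vmec_utils

# dense matrix mapping DESC double-Fourier coefficients to the VMEC representation
A = jnp.array(desc.vmec_utils.ptolemy_linear_transform(basis.modes, vmec_modes=None, helicity=None, NFP=None)[0])
%timeit _ = jax.block_until_ready(jnp.dot(A, x))
# 1.01 ms ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```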
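Because the conversion is linear, the matrix could also be precomputed once from any implementation of the map itself. Below is a minimal JAX sketch that assumes a hypothetical linear function `linear_map` standing in for the real conversion. `jax.jacfwd` of a linear function is exactly its matrix (at any evaluation point), and after that each application is a single matvec.

```python
import jax
import jax.numpy as jnp


def linear_map(x):
    # hypothetical stand-in for a linear mode conversion (half-sums and half-differences)
    n = x.shape[0] // 2
    return jnp.concatenate([(x[:n] + x[n:]) / 2, (x[:n] - x[n:]) / 2])


x0 = jnp.zeros(8)
A = jax.jacfwd(linear_map)(x0)  # (8, 8); for a linear map this does not depend on x0


@jax.jit
def apply(A, x):
    return A @ x


x = jnp.arange(8.0)
y = apply(A, x)
```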

And we can do even better if we use a sparse representation for A:

```python
import jax
import jax.experimental.sparse

# store A in batched-COO (BCOO) sparse format
As = jax.experimental.sparse.BCOO.fromdense(A)

@jax.jit
def spdot(A, x):
    return A @ x

%timeit _ = jax.block_until_ready(spdot(As, x))
# 32.4 μs ± 1.68 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```

For pretty much all resolutions, just multiplying by A seems to be the fastest option, and at high resolution a sparse A seems to be the best bet.
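At high resolution it may be preferable not to form the dense A at all, since `BCOO.fromdense` needs the dense array in memory first. Below is a minimal sketch that assumes the nonzeros are already available as (row, column, value) triplets. How those triplets would be produced for the actual transform is hypothetical here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import sparse

# hypothetical triplets for a 4x5 matrix with 5 nonzeros
rows = np.array([0, 1, 1, 2, 3])
cols = np.array([0, 1, 4, 2, 3])
vals = np.array([1.0, 0.5, -0.5, 2.0, 3.0])

# BCOO takes (data, indices), with indices of shape (nse, 2)
indices = jnp.asarray(np.stack([rows, cols], axis=1))
As = sparse.BCOO((jnp.asarray(vals), indices), shape=(4, 5))


@jax.jit
def spdot(A, x):
    return A @ x


x = jnp.arange(5.0)
y = spdot(As, x)
```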

f0uriest • Jun 04 '25

So from some profiling, it seems like using FFTs might not even be worth it compared to just using a factorized direct method?

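```python
import jax
import jax.numpy as jnp
import numpy as np

# compile the zero-padded inverse FFT once; n and axis must be static
ifft = jax.jit(jnp.fft.ifft, static_argnames=["n", "axis"])

for M in [5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100]:
    m = 2*M+1
    x = np.random.random((m,m,m))
    A = np.random.random((4*M+1,m))
    print("M=",M)
    print("dot time")
    # note: jnp.dot contracts A's last axis with x's second-to-last axis (axis 1)
    %timeit _ = jnp.dot(A, x).block_until_ready()
    print("fft time")
    %timeit _ = ifft(x, n=4*M+1, axis=0).block_until_ready()
```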
Results (IPython %timeit, mean ± std. dev. of 7 runs; loops per run in parentheses):

| M | dot time | fft time |
|---|---|---|
| 5 | 23.7 μs ± 2.83 μs (10,000) | 93.9 μs ± 1.8 μs (10,000) |
| 10 | 77.4 μs ± 7.4 μs (10,000) | 433 μs ± 162 μs (1,000) |
| 15 | 263 μs ± 12 μs (1,000) | 1.99 ms ± 37.9 μs (1,000) |
| 20 | 377 μs ± 12.6 μs (1,000) | 2.14 ms ± 39.9 μs (100) |
| 25 | 824 μs ± 17 μs (1,000) | 12.2 ms ± 297 μs (100) |
| 30 | 1.53 ms ± 18 μs (1,000) | 7.17 ms ± 65.7 μs (100) |
| 40 | 4.25 ms ± 339 μs (100) | 20.6 ms ± 4.39 ms (10) |
| 50 | 6.65 ms ± 76.1 μs (100) | 53.6 ms ± 6.67 ms (10) |
| 60 | 10.8 ms ± 94.1 μs (100) | 87.5 ms ± 4.93 ms (10) |
| 70 | 18.3 ms ± 244 μs (100) | 169 ms ± 12.4 ms (10) |
| 80 | 34.7 ms ± 716 μs (10) | 231 ms ± 15.4 ms (10) |
| 90 | 82.4 ms ± 622 μs (10) | 259 ms ± 47.2 ms (1) |
| 100 | 120 ms ± 3.73 ms (10) | 417 ms ± 84.7 ms (1) |

That said, I'm a bit suspicious that the FFT doesn't pull ahead even at very high resolutions. Can someone else run this and let me know?
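One way to double-check is to time only compiled code and have both sides transform the same axis. In the minimal sketch below, both functions are jitted and warmed up before timing, so compilation is excluded. The matrix transform uses `jnp.tensordot` over axis 0, which is the axis the FFT acts on; by contrast, `jnp.dot(A, x)` with a 3D `x` contracts `x`'s second-to-last axis. The helper names are hypothetical. FFT lengths may also be worth checking: 4M+1 is prime for several of the tested M (e.g. 101 at M=25, 241 at M=60, 401 at M=100), and prime lengths are typically a slower path in FFT libraries. Whether this explains the timings above is untested here.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def time_compiled(fn, *args, repeats=20):
    """Hypothetical helper: median wall time of a jitted fn, excluding compilation."""
    fn(*args).block_until_ready()  # warm-up call triggers compilation
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args).block_until_ready()
        times.append(time.perf_counter() - t0)
    return float(np.median(times))


@jax.jit
def mmt(A, x):
    # contract A's columns with x's axis 0 (the axis the FFT below acts on)
    return jnp.tensordot(A, x, axes=(1, 0))


@jax.jit
def fft_axis0(x):
    # zero-padded inverse FFT to 4M+1 points along axis 0, with M recovered from x's shape
    return jnp.fft.ifft(x, n=4 * ((x.shape[0] - 1) // 2) + 1, axis=0)


for M in [5, 10, 20]:
    m = 2 * M + 1
    x = jnp.asarray(np.random.random((m, m, m)))
    A = jnp.asarray(np.random.random((4 * M + 1, m)))
    print(M, time_compiled(mmt, A, x), time_compiled(fft_axis0, x))
```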

f0uriest • Jun 11 '25

Won't the speed improve a lot if we do rfft2(dct(f, radial axis), flux axes), and idct(irfft2(...)) for the inverse?

unalmis • Jun 11 '25

This seems to suggest the opposite. Factorizing the 3D transform into three 1D transforms is definitely worth it, but using an FFT/DCT doesn't seem to be any faster than matrix multiplication for the sizes we care about (and possibly for much larger sizes as well?).
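To make the factorization concrete: for a tensor-product basis (e.g. Chebyshev × Fourier × Fourier), the full 3D transform matrix is the Kronecker product of three 1D matrices. Applying the 1D matrices one axis at a time gives the same result at much lower cost, roughly O(n^4) instead of O(n^6) for n points and n modes per axis. The minimal NumPy sketch below uses random stand-in matrices `R`, `T`, `Z` in place of hypothetical radial, poloidal, and toroidal evaluation matrices.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
# random stand-ins for 1D radial, poloidal, toroidal transform matrices
R, T, Z = (rng.standard_normal((n, n)) for _ in range(3))
c = rng.standard_normal((n, n, n))  # coefficients c[l, m, k]

# full transform: one (n^3 x n^3) matrix, O(n^6) work per application
f_full = np.kron(np.kron(R, T), Z) @ c.ravel()

# factorized: three 1D transforms, O(n^4) work each
f_fact = np.einsum("ai,ijk->ajk", R, c)
f_fact = np.einsum("bj,ajk->abk", T, f_fact)
f_fact = np.einsum("ck,abk->abc", Z, f_fact)
```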

f0uriest • Jun 11 '25

I vaguely recall something similar when we first implemented the toroidal FFT. The biggest savings came from just factorizing the 3D transform into 2D + 1D; whether we used MMT or FFT for the 1D part didn't seem to make much difference:

```python
import jax
import numpy as np
from desc.basis import FourierZernikeBasis
from desc.grid import QuadratureGrid
from desc.transform import Transform

for M in [5, 10, 15, 20, 25]:
    basis = FourierZernikeBasis(M,M,M)
    grid = QuadratureGrid(M,M,M)
    # same basis and grid, three different transform methods
    transform1 = Transform(grid, basis, method="direct1", build=True)
    transform2 = Transform(grid, basis, method="direct2", build=True)
    transform3 = Transform(grid, basis, method="fft", build=True)
    t1 = jax.jit(transform1.transform)
    t2 = jax.jit(transform2.transform)
    t3 = jax.jit(transform3.transform)
    x = np.random.random(basis.num_modes)
    # warm-up calls so compilation is not included in the timings
    _ = t1(x).block_until_ready()
    _ = t2(x).block_until_ready()
    _ = t3(x).block_until_ready()
    print("M=",M)
    print("direct1")
    %timeit _ = t1(x).block_until_ready()
    print("direct2")
    %timeit _ = t2(x).block_until_ready()
    print("fft")
    %timeit _ = t3(x).block_until_ready()
```
Results (IPython %timeit, mean ± std. dev. of 7 runs; loops per run in parentheses):

| M | direct1 | direct2 | fft |
|---|---|---|---|
| 5 | 24.1 μs ± 2.09 μs (10,000) | 17.5 μs ± 2.69 μs (10,000) | 21.4 μs ± 1.32 μs (100,000) |
| 10 | 1.23 ms ± 37.2 μs (1,000) | 39.6 μs ± 6.93 μs (10,000) | 50.7 μs ± 7.04 μs (10,000) |
| 15 | 8.22 ms ± 156 μs (100) | 272 μs ± 4.14 μs (10,000) | 229 μs ± 10.3 μs (1,000) |
| 20 | 42.6 ms ± 363 μs (10) | 365 μs ± 18.9 μs (1,000) | 393 μs ± 7.31 μs (1,000) |
| 25 | 161 ms ± 2.73 ms (10) | 573 μs ± 10.2 μs (1,000) | 658 μs ± 11.7 μs (1,000) |

f0uriest • Jun 11 '25

Also note that converting our double Fourier basis to the form expected by np.fft.fft2 costs about as much as doing the FFT itself. Plus, in the common case where we have poloidal up/down symmetry and use a symmetric grid, I don't think there's any easy way to do FFTs other than computing on the full domain and then truncating.

All that said, I think it probably makes more sense to just use direct matrix transforms poloidally. We can still get some speedup by separating the radial and poloidal parts while using dense matrices for both.
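For the Zernike case, where (as noted in the checklist) the radial index couples to the poloidal one, one way to separate the radial and poloidal parts with dense matrices is padding. Store the coefficients in a dense C[l, m] with zeros wherever the Zernike selection rule (l >= |m|, l - |m| even) fails, and build one radial matrix per m. Below is a minimal NumPy sketch under those assumptions. The sign convention (cos for m >= 0, sin for m < 0) and all names are hypothetical; this is not DESC's implementation.

```python
from math import factorial

import numpy as np


def zernike_radial(l, m, rho):
    """Standard Zernike radial polynomial R_l^m(rho), for l >= m >= 0 with l - m even."""
    return sum(
        (-1) ** k * factorial(l - k)
        / (factorial(k) * factorial((l + m) // 2 - k) * factorial((l - m) // 2 - k))
        * rho ** (l - 2 * k)
        for k in range((l - m) // 2 + 1)
    )


L = M = 4
rho = np.linspace(0.1, 1.0, 7)
theta = np.linspace(0, 2 * np.pi, 9, endpoint=False)
ms = np.arange(-M, M + 1)
valid = np.array([[l >= abs(m) and (l - abs(m)) % 2 == 0 for m in ms] for l in range(L + 1)])

# R[j, i, l]: radial matrix for the j-th m, zero where (l, m) is not a Zernike mode
R = np.zeros((len(ms), len(rho), L + 1))
for j, m in enumerate(ms):
    for l in range(abs(m), L + 1, 2):
        R[j, :, l] = zernike_radial(l, abs(m), rho)

# poloidal matrix P[t, j]: cos(m theta) for m >= 0, sin(|m| theta) for m < 0 (assumed convention)
P = np.where(ms >= 0, np.cos(np.outer(theta, ms)), np.sin(np.outer(theta, np.abs(ms))))

rng = np.random.default_rng(0)
C = rng.standard_normal((L + 1, len(ms))) * valid  # zero out non-modes

# step 1 (radial, depends on m): g[i, j] = sum_l R_l^{|m_j|}(rho_i) C[l, j]
g = np.einsum("jil,lj->ij", R, C)
# step 2 (poloidal, one dense matrix): f[i, t] = sum_j g[i, j] P[t, j]
f = g @ P.T
```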

f0uriest • Jun 11 '25

@dpanici @YigitElma Try this on GPUs and profile direct3 against the other methods.

dpanici • Jun 18 '25

Memory benchmark result

```diff
|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
  test_objective_jac_w7x                 |    1.17 %    |     3.847e+03      |     3.892e+03      |    44.92     |       33.66        |       32.67        |
+ test_proximal_jac_w7x_with_eq_update   |   -17.44 %   |     6.961e+03      |     5.747e+03      |   -1214.26   |       73.40        |       168.35       |
+ test_proximal_freeb_jac                |   -98.09 %   |     1.319e+04      |     2.524e+02      |  -12935.98   |        8.63        |       79.45        |
+ test_proximal_freeb_jac_blocked        |   -96.48 %   |     7.583e+03      |     2.666e+02      |   -7316.73   |        8.63        |       70.72        |
+ test_proximal_freeb_jac_batched        |   -96.38 %   |     7.631e+03      |     2.764e+02      |   -7354.12   |        8.53        |       72.63        |
+ test_proximal_jac_ripple               |   -62.17 %   |     7.563e+03      |     2.861e+03      |   -4702.18   |       28.73        |       74.57        |
  test_proximal_jac_ripple_spline        |    2.52 %    |     3.471e+03      |     3.558e+03      |    87.54     |       75.77        |       75.58        |
  test_eq_solve                          |    1.25 %    |     2.047e+03      |     2.073e+03      |    25.63     |       130.36       |       130.37       |
```

For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.

github-actions[bot] • Jul 29 '25