
Making transforms more efficient

Open f0uriest opened this issue 1 year ago • 1 comment

Trying to coalesce a few things we've been discussing regarding making transforms more efficient.

1. Partial basis.evaluate

When calling basis.evaluate you should be able to specify which coordinates to actually evaluate. The default would be all of them, but evaluating only the radial part, e.g., would make the next step easier. It would also modularize things better (right now we directly call zernike_radial in several places where we really mean basis.evaluate restricted to the r coordinate), making it easier to swap in different basis types (finite element, Chebyshev, etc.).
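A minimal sketch of what a partial evaluate could look like, using a toy separable basis. The class name, the `coords` string argument, and the placeholder radial factor are all hypothetical, not the current DESC API:

```python
import numpy as np

class SeparableBasisSketch:
    """Toy stand-in for a DESC Basis whose functions factor per coordinate.
    The `coords` argument is the proposed (hypothetical) partial-evaluate API."""

    def __init__(self, modes):
        self.modes = np.asarray(modes)  # columns: (l, m, n)

    def evaluate(self, nodes, coords="rtz"):
        # nodes columns: (rho, theta, zeta); evaluate only the requested factors
        nodes = np.atleast_2d(nodes)
        out = np.ones((nodes.shape[0], self.modes.shape[0]))
        if "r" in coords:
            # placeholder radial part (real code would call zernike_radial)
            out *= nodes[:, :1] ** np.abs(self.modes[:, 0])
        if "t" in coords:
            out *= np.cos(self.modes[:, 1] * nodes[:, 1:2])
        if "z" in coords:
            out *= np.cos(self.modes[:, 2] * nodes[:, 2:3])
        return out
```

Because each basis function is a product of per-coordinate factors, the full evaluation is just the elementwise product of the partial ones, which is what makes the factorizations in the later points possible.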

2. Partial transform.transform

Similarly, transform.transform should allow you to transform only one or two coordinates at a time. For example, instead of lmn -> rtz you could do lmn -> rmn to evaluate only the radial part. This would mean breaking the transform matrices into individual per-coordinate parts, which we sort of already do, since for FFTs we only use matrices for the radial/poloidal part. This should also make it easier to do poloidal FFTs (#641).
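To illustrate on a tensor product mode set, here is a toy factorized transform showing that applying the per-coordinate matrices one at a time (lmn -> rmn -> rtn -> rtz) reproduces the monolithic transform. The matrices are random stand-ins for actual basis evaluations:

```python
import numpy as np

rng = np.random.default_rng(0)
L, M, N = 4, 5, 6          # spectral resolution in each coordinate
nr, nt, nz = 7, 8, 9       # number of nodes in each coordinate

# 1-D transform matrices (random stand-ins for the real basis functions)
Ar = rng.standard_normal((nr, L))
At = rng.standard_normal((nt, M))
Az = rng.standard_normal((nz, N))
c = rng.standard_normal((L, M, N))   # tensor-product spectral coefficients

# monolithic transform: full Kronecker-product matrix applied in one shot
A_full = np.einsum("rl,tm,zn->rtzlmn", Ar, At, Az).reshape(nr * nt * nz, L * M * N)
f_full = (A_full @ c.ravel()).reshape(nr, nt, nz)

# partial transforms: lmn -> rmn -> rtn -> rtz, one coordinate at a time
c_rmn = np.einsum("rl,lmn->rmn", Ar, c)     # radial part only
c_rtn = np.einsum("tm,rmn->rtn", At, c_rmn)
f_fact = np.einsum("zn,rtn->rtz", Az, c_rtn)

assert np.allclose(f_full, f_fact)
```

The intermediate `c_rmn` is exactly the lmn -> rmn partial transform described above, and stopping there is all that is needed when only the radial part changes.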

3. Allow you to promise "same nodes" when updating transform.grid

i.e., with transform.update_grid(new_grid, promise_same="r") you would be promising that this grid is the same as the old one in the radial direction, so that only the poloidal and toroidal parts of the transform matrices need to be recomputed. For example, in map_coordinates, assuming rho is both an input and output coordinate:

```python
transforms = get_transforms(names..., grid=initial_guess_grid)
rho = coords[:, 0]

def rootfun(x):
    nodes = np.column_stack([rho, x])  # x is only theta/zeta
    new_grid = Grid(nodes)
    # radial nodes unchanged, so only the poloidal/toroidal parts are rebuilt
    transforms.update_grid(new_grid, promise_same="r")
    data = _compute_fun(names, params, transforms, ...)
    ...
```

This is more along the lines of "partial evaluation" rather than "partial summation", but it would likely see similar improvements (#1154).
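A sketch of how promise_same could work internally: cache the per-coordinate matrices and rebuild only the ones not promised unchanged. The class and the `build` callback here are hypothetical, not the existing Transform API:

```python
import numpy as np

class PartialTransformSketch:
    """Toy transform caching one matrix per coordinate; `promise_same`
    is the hypothetical flag proposed above."""

    def __init__(self, grid_nodes, build):
        # build(coord, nodes_1d) -> transform matrix for that coordinate
        self.build = build
        self.mats = {c: build(c, grid_nodes[:, i]) for i, c in enumerate("rtz")}

    def update_grid(self, grid_nodes, promise_same=""):
        for i, c in enumerate("rtz"):
            if c not in promise_same:  # skip coordinates promised unchanged
                self.mats[c] = self.build(c, grid_nodes[:, i])
```

In the map_coordinates example above, every Newton iteration would then rebuild only the theta/zeta matrices while reusing the cached radial one.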

4. Move logic from transform._check_inputs_fft to Grid and Basis

For most grid classes we don't really need to do these checks, since we know by construction that certain properties are or are not satisfied. This would simplify the transforms and speed up construction a bit, since we wouldn't need to do a bunch of checks. (It would also be nice if we could get rid of some of the indexing operations there in favor of just reshaping, but I'm not sure how that plays with symmetry, where certain modes don't exist.)
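One way this could look: the grid class advertises FFT compatibility as a construction-time guarantee, so the transform takes a fast path instead of inspecting the nodes. The names here are illustrative stand-ins, not the existing DESC API:

```python
import numpy as np

class TensorGridSketch:
    """Toy grid built as a meshgrid with uniform angles, so FFT
    compatibility is known by construction and need not be re-checked."""

    fft_poloidal = True  # guaranteed by how the nodes are built below
    fft_toroidal = True

    def __init__(self, nr, nt, nz):
        r = np.linspace(0.1, 1, nr)
        t = np.linspace(0, 2 * np.pi, nt, endpoint=False)  # uniform -> FFT ok
        z = np.linspace(0, 2 * np.pi, nz, endpoint=False)
        R, T, Z = np.meshgrid(r, t, z, indexing="ij")
        self.nodes = np.column_stack([R.ravel(), T.ravel(), Z.ravel()])

def check_inputs_fft(grid):
    # fast path: trust the grid's construction-time guarantee
    if getattr(grid, "fft_poloidal", False):
        return True
    # slow path: inspect the nodes (stand-in for transform._check_inputs_fft)
    t = np.unique(grid.nodes[:, 1])
    return np.allclose(np.diff(t), t[1] - t[0])
```

Grids assembled from arbitrary node arrays would simply not set the flag and fall through to the existing runtime checks.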

5. Extend notion of "meshgrid" to partial meshgrids/bases

e.g., ConcentricGrid has the same rho/theta at each zeta, so it's sort of a tensor product grid (with some accompanying simplifications). Similarly, a FourierZernikeBasis is a partial tensor product.
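A toy demonstration of the partial tensor product structure: an unstructured 2-D (rho, theta) point set repeated at every zeta plane, which is exactly the property that lets the toroidal direction be handled separately:

```python
import numpy as np

# toy ConcentricGrid-like nodes: the (rho, theta) point set is identical on
# every zeta plane, so the grid is a tensor product of an unstructured 2-D
# (rho, theta) set with a 1-D zeta set
rt = np.array([[0.5, 0.0], [0.5, np.pi], [1.0, 0.0], [1.0, np.pi / 2]])
zeta = np.linspace(0, 2 * np.pi, 3, endpoint=False)
nodes = np.vstack([np.column_stack([rt, np.full(len(rt), z)]) for z in zeta])

# partial-meshgrid check: every zeta plane carries the same (rho, theta) set
planes = nodes.reshape(len(zeta), len(rt), 3)
is_partial_tensor = all(np.allclose(p[:, :2], rt) for p in planes)
assert is_partial_tensor
```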

f0uriest avatar Sep 03 '24 15:09 f0uriest

Yes, I think using partial summation techniques would yield significant improvements. Some thoughts:

  • The given example in 3 may require #1207 first; otherwise the nodes would change in the Newton iteration while the transform won't get updated.
  • For 1, 2, and 3, the partial evaluation steps would be a constant factor improvement; the complexity of evaluating a function over a 3D grid is still $\mathcal{O}(N^6)$. Probably still useful, though.
  • On a tensor-product grid, you can proceed from partial evaluation to the partial summation mentioned in #1154 to reduce this to $\mathcal{O}(N^4)$.
  • Some optimization metrics are 2D objectives over flux surfaces that implicitly rely on the continuity of the physics to minimize the objective throughout the volume. For these, the best node set is probably a tensor product, so using partial summation will reduce the computation cost from $\mathcal{O}(LN^4) \to \mathcal{O}(LN^3)$, where $L$ is the number of flux surfaces and $N$ is the maximum of the number of poloidal and toroidal nodes.
  • It might be better to always default to a tensor product node set (orthogonal polynomial in radial and Fourier in poloidal and toroidal) rather than FourierZernike, so that 3D FFTs and partial summation can be used (in particular a DPT in radial, a real FFT in poloidal, and an FFT in toroidal).
  • Lower priority, but some multigrid optimizations would only need one grid, and the transforms for that grid, if some of the objectives only target a subset of the flux surfaces of the full grid. Right now we recreate transforms for each objective instead of reusing the relevant subset of the transform on the full grid.
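To make the FFT point concrete, here is a small check (not DESC code) that evaluating a 2-D Fourier series on a uniform theta/zeta tensor grid by zero-padding the coefficients onto the FFT frequency grid matches the direct matrix evaluation:

```python
import numpy as np

rng = np.random.default_rng(0)
M, N = 3, 4                    # max poloidal / toroidal mode numbers
nt, nz = 8, 16                 # uniform grid sizes (>= 2*M+1 and 2*N+1)

# complex Fourier coefficients c[m, n] for exp(i*(m*theta + n*zeta))
c = rng.standard_normal((2 * M + 1, 2 * N + 1)) * 1j
c += rng.standard_normal((2 * M + 1, 2 * N + 1))

theta = np.linspace(0, 2 * np.pi, nt, endpoint=False)
zeta = np.linspace(0, 2 * np.pi, nz, endpoint=False)
m = np.arange(-M, M + 1)
n = np.arange(-N, N + 1)

# direct evaluation: dense matrices in each angle
f_direct = np.einsum(
    "tm,mn,zn->tz",
    np.exp(1j * np.outer(theta, m)), c, np.exp(1j * np.outer(zeta, n)),
)

# FFT evaluation: pad coefficients onto the full frequency grid
pad = np.zeros((nt, nz), dtype=complex)
pad[np.ix_(m % nt, n % nz)] = c          # place each mode at its FFT bin
f_fft = np.fft.ifft2(pad) * nt * nz      # undo ifft2's 1/(nt*nz) normalization

assert np.allclose(f_direct, f_fft)
```

The padding step in the middle is the same tensor-product padding discussed later in this thread; once the modes sit on the full frequency grid, the transform is a single (real or complex) FFT per direction.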

unalmis avatar Sep 05 '24 20:09 unalmis

As discussed in #1508, indexing operations are bottlenecks for the transforms. Below is a brief recap and my suggested solution, which would also greatly reduce/simplify the DESC code.

  1. Recall there is significant indexing overhead in switching from the DESC Fourier basis to the standard complex FFT basis. As mentioned here, this is a bottleneck in the runtime analysis since it takes the same time as an FFT. This further motivates #1531. However, let's ignore this for now.
  2. There is also indexing overhead to pad a mode set with zeros to form a tensor product mode set, e.g. padding the Zernike or symmetric mode sets. This padding step is a fundamental requirement (if we want to avoid jax loops) for any algorithm that factorizes n-d transforms into lower-dimensional transforms, such as partial summation, which can be implemented with either an n-d factorized MMT or an n-d FFT. (The DESC direct3 method in #1508 corresponds to a factorized MMT.)
  3. Hence, whether you use an MMT or an FFT, you have to pad the basis being transformed to a tensor product prior to the transformation if you want an efficient algorithm. Either method of implementing partial summation will be significantly faster. Frankly, this suggests the best way to do transforms is to store only the tensor product mode set in the Basis class. For example, when there is some type of symmetry, don't delete those modes; instead make sure their spectral coefficients are somehow constrained to be zero. Likewise for FourierZernike, make the LM modes a tensor product of size L*M and constrain the spectral coefficients of the fake modes to be zero.
  4. Point 3 would improve the partial summation implementation described in #1154 with FourierZernike.

Right now we are doing 3 anyway, in a roundabout way: deleting modes in the basis, computing and storing the unique and inverse modes, doing indexing operations to evaluate each basis, then attempting to reconstruct the original tensor product mode set with a bunch of indexing when the transforms are called. This is the reason our transforms.py class is complicated when it could be as simple as this.
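A toy version of the pad-then-factorize idea from points 2 and 3: scatter the reduced (symmetry-restricted) coefficients into the full tensor product with the fake modes held at zero, then apply the factorized transform. It matches the direct MMT on the indexed mode set:

```python
import numpy as np

rng = np.random.default_rng(0)
L, M = 5, 4
# reduced mode set: e.g. symmetry deletes some (l, m) pairs
mask = rng.random((L, M)) > 0.3          # True where the mode exists
c_reduced = rng.standard_normal(mask.sum())

# pad: scatter reduced coefficients into the full L x M tensor product,
# with the "fake" (deleted) modes constrained to zero
c_full = np.zeros((L, M))
c_full[mask] = c_reduced

# factorized 2-D transform on the padded tensor product
nr, nt = 6, 7
Ar = rng.standard_normal((nr, L))        # stand-in radial matrix
At = rng.standard_normal((nt, M))        # stand-in poloidal matrix
f_fact = np.einsum("rl,tm,lm->rt", Ar, At, c_full)

# reference: direct MMT on the reduced (indexed) mode set
ll, mm = np.nonzero(mask)
A_reduced = Ar[:, ll][:, None, :] * At[:, mm][None, :, :]  # (nr, nt, n_modes)
f_direct = A_reduced @ c_reduced

assert np.allclose(f_fact, f_direct)
```

Storing `mask` and `c_full` up front would replace the delete/unique/inverse/reconstruct indexing chain with a single scatter.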

unalmis avatar Jul 29 '25 08:07 unalmis

A possible issue I see with the suggestion of keeping the full basis and constraining certain modes to be zero based on symmetry (my understanding of 3; I am envisioning it as something like a linear constraint) is the issue we already see in #782: simply extending the length of the x vector with corresponding linear constraints that in the end don't add any new DOFs can still affect the computation, due to the numerics of inverting the linear constraint matrix.

I agree this would be a clean way to do what you suggest, but I am slightly hesitant to increase the size of the pre-constrained x like that. Maybe there is a different way you had in mind that would avoid this.
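For what it's worth, one reading of the suggestion that avoids enlarging the pre-constrained x (this is an assumption about the intent, not a settled design): keep the optimization vector at its reduced size and scatter it into the full coefficient array inside compute, so no linear constraint matrix is ever built or inverted:

```python
import numpy as np

rng = np.random.default_rng(0)
mask = np.array([True, False, True, True, False])   # nonzero modes
A_full = rng.standard_normal((4, mask.size))        # tensor-product matrix

def compute(x_reduced):
    # scatter into the full coefficient array inside compute; x itself
    # never grows, so no extra DOFs or constraint inversion are introduced
    c_full = np.zeros(mask.size)
    c_full[mask] = x_reduced
    return A_full @ c_full

x = rng.standard_normal(mask.sum())
# equivalent to transforming with the reduced (indexed) matrix
assert np.allclose(compute(x), A_full[:, mask] @ x)
```

The padding cost moves into the forward compute (a single scatter), while the optimizer and the linear constraint machinery still only ever see the reduced x.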

dpanici avatar Jul 29 '25 13:07 dpanici

Then just store the full tensor product mode set and a mask of the nonzero modes (instead of deleting them) in the basis. If you need only the nonzero modes, e.g. at the final optimization step, just query nz_modes = modes[nonzero_mask].
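A minimal sketch of that layout, with a toy class and a hypothetical symmetry rule standing in for the real mode filtering:

```python
import numpy as np

class MaskedBasisSketch:
    """Toy basis storing the full tensor-product mode set plus a mask of
    the nonzero modes, instead of deleting modes outright."""

    def __init__(self, L, M, keep):
        l, m = np.meshgrid(np.arange(L), np.arange(-M, M + 1), indexing="ij")
        self.modes = np.column_stack([l.ravel(), m.ravel()])
        self.nonzero_mask = keep(self.modes)

    @property
    def nz_modes(self):
        # reduced mode set, queried only where needed (e.g. final opt step)
        return self.modes[self.nonzero_mask]

# example: symmetry-like rule keeping only modes with l + m even
b = MaskedBasisSketch(3, 2, keep=lambda modes: (modes[:, 0] + modes[:, 1]) % 2 == 0)
assert len(b.nz_modes) < len(b.modes)
```

Transforms would then always operate on the full `modes` array (tensor product, FFT/factorized-MMT friendly), while the mask alone carries the symmetry information.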

unalmis avatar Jul 29 '25 16:07 unalmis