
Pre-compute `Curve` and `Surface` transforms in objectives that compute the magnetic field

Open ddudt opened this issue 1 year ago • 4 comments

`QuadraticFlux.compute` calls `_Coil.compute_magnetic_field`, which calls `Curve.compute`. The error is then thrown at this line:

```python
if transforms is None:
    transforms = get_transforms(
        names, obj=self, grid=grid, jitable=True, **kwargs
    )
```

You fixed the error that occurs when these transforms are created inside a jitable function, but it is wasteful to rebuild the same transforms every time the objective is computed. It would be safer and more efficient to pre-compute them in `QuadraticFlux.build` and then pass them through to the relevant compute functions. That would also fix this bug and speed up the code.
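A minimal sketch of the build-once, reuse-in-compute pattern suggested above. Everything here is a hypothetical, stand-alone stand-in: `build_transforms` plays the role of `get_transforms`, `ToyQuadraticFlux` the role of `QuadraticFlux`, and the transform names are illustrative; the real DESC signatures differ. The counter only shows that transforms are built once in `build` rather than on every `compute` call.

```python
# Hypothetical stand-ins for illustration only; not DESC's API.
BUILD_COUNT = 0


def build_transforms(names, grid):
    """Stand-in for get_transforms: 'build' one transform per requested name."""
    global BUILD_COUNT
    BUILD_COUNT += 1
    # A "transform" here is just a copy of the grid nodes, keyed by name.
    return {name: list(grid) for name in names}


def compute_magnetic_field(coords, transforms=None, grid=None):
    """Stand-in for _Coil.compute_magnetic_field: rebuilds only if none are passed."""
    if transforms is None:
        transforms = build_transforms(["x", "x_s"], grid)  # illustrative names
    # Dummy "field" that depends on the number of nodes in the transform.
    return [c * len(transforms["x"]) for c in coords]


class ToyQuadraticFlux:
    def __init__(self, grid):
        self.grid = grid
        self._transforms = None

    def build(self):
        # Pre-compute once, outside any jitted function.
        self._transforms = build_transforms(["x", "x_s"], self.grid)

    def compute(self, coords):
        # Reuse the cached transforms instead of rebuilding them on every call.
        return compute_magnetic_field(coords, transforms=self._transforms)


obj = ToyQuadraticFlux(grid=[0.0, 0.5, 1.0])
obj.build()
for _ in range(10):
    B = obj.compute([1.0, 2.0])
print(BUILD_COUNT)  # 1
```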

Originally posted by @ddudt in https://github.com/PlasmaControl/DESC/pull/1069#pullrequestreview-2138807945

ddudt, Jun 26 '24 17:06

Finding an elegant way to do this is difficult, because the `compute_magnetic_field` function lies outside the normal compute hierarchy that `get_transforms` relies on. As the code stands, we would have to put some clunky logic inside the `build` of every objective that uses `compute_magnetic_field`, to check whether the field passed in is a `Coil` and, if so, compute the required transforms for it.

dpanici, Jun 26 '24 17:06

I agree this is something we should figure out longer term, but as @dpanici points out, it's a non-trivial amount of work to make it work well. FWIW, all the transforms used by coils and magnetic field classes are just Fourier series, so they are basically free to compute; I'm not even sure precomputing them saves much time. Also, in theory JAX may optimize that away as long as the grids are static.
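To illustrate why such transforms are cheap: for a curve parameterized by a toroidal angle, a "transform" amounts to a matrix of cos/sin values at the grid nodes, and evaluating the curve is a single matrix-vector product. This is a simplified, stand-alone sketch in plain Python; the basis ordering and the function names are assumptions, not DESC's `Transform` class.

```python
import math


def fourier_basis(phi_nodes, N):
    """Rows = grid nodes; columns = [1, cos(phi), sin(phi), ..., cos(N phi), sin(N phi)]."""
    rows = []
    for phi in phi_nodes:
        row = [1.0]
        for n in range(1, N + 1):
            row += [math.cos(n * phi), math.sin(n * phi)]
        rows.append(row)
    return rows


def evaluate(basis, coeffs):
    """Apply a pre-built basis matrix to a coefficient vector."""
    return [sum(b * c for b, c in zip(row, coeffs)) for row in basis]


# R(phi) = 10 + 1*cos(phi), sampled at 4 equally spaced points.
nodes = [2 * math.pi * k / 4 for k in range(4)]
A = fourier_basis(nodes, N=1)       # built once per grid
R = evaluate(A, [10.0, 1.0, 0.0])   # reapplied cheaply for new coefficients
print([round(r, 12) for r in R])    # [11.0, 10.0, 9.0, 10.0]
```

Building `A` takes only (number of nodes) × (2N + 1) trig evaluations, which is why rebuilding it on each call may cost little in practice.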

f0uriest, Jun 26 '24 18:06

I also edited the title, since we should be able to pre-compute transforms for `Surface` objects as well, such as `FourierCurrentPotentialField`, though again this is not straightforward to do elegantly.

dpanici, Jul 02 '24 14:07

Just to leave a further note:

It is not hard to do this on an objective-by-objective basis for lone magnetic field objects: we can simply check the type of the object (`FourierRZCoil`, `FourierCurrentPotentialField`, `ToroidalMagneticField`, etc.) and then create the transform with the appropriate grid for that object. However, this is

  • a) cumbersome: even if we put this in a util function, it would still be a lot of if statements, and
  • b) this won't work for `MixedCoilSet` or `SumMagneticField` objects, where it matters internally which object gets which grid; that would require a PyTree structure like the one the `_Coil` objectives have.

One way forward could be to have a utility function which

  • takes in the magnetic field object
  • uses if statements to make the appropriate grid for it
  • calls `get_transforms` with that grid and returns the transform dict for that object

Then use a PyTree utility to apply this utility function to each leaf of the given `MagneticField` object (each leaf being any `MagneticField` object that is not a `MixedCoilSet` or `SumMagneticField`) and get back the correctly pre-computed transforms for that `MagneticField` PyTree.
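A minimal, stand-alone sketch of the utility-function-plus-PyTree idea. All classes are hypothetical placeholders for the DESC types named above, the grid specifications are arbitrary, and `precompute_leaf` stands in for a real call to `get_transforms`. The recursive `tree_map_fields` (also hypothetical) treats the `MixedCoilSet`/`SumMagneticField` analogues as internal nodes and every other field as a leaf, so each leaf gets its own transforms and the result mirrors the input's nesting.

```python
# Hypothetical placeholder classes; not DESC's real class hierarchy.
class FourierRZCoil:
    pass


class FourierCurrentPotentialField:
    pass


class ToroidalMagneticField:
    pass


class MixedCoilSet:
    def __init__(self, *coils):
        self.coils = list(coils)


class SumMagneticField:
    def __init__(self, *fields):
        self.fields = list(fields)


def precompute_leaf(field):
    """Pick a grid for one leaf field and 'build' its transforms (stand-in for get_transforms)."""
    if isinstance(field, FourierRZCoil):
        grid = {"type": "LinearGrid", "N": 64}           # hypothetical 1D grid along the coil
    elif isinstance(field, FourierCurrentPotentialField):
        grid = {"type": "LinearGrid", "M": 32, "N": 32}  # hypothetical 2D surface grid
    elif isinstance(field, ToroidalMagneticField):
        return {}                                        # assumed analytic: no transforms needed
    else:
        raise TypeError(f"no grid rule for {type(field).__name__}")
    return {"grid": grid, "transforms": f"transforms for {type(field).__name__}"}


def tree_map_fields(fn, field):
    """Apply fn to every leaf field, recursing through container fields."""
    if isinstance(field, MixedCoilSet):
        return [tree_map_fields(fn, c) for c in field.coils]
    if isinstance(field, SumMagneticField):
        return [tree_map_fields(fn, f) for f in field.fields]
    return fn(field)


field = SumMagneticField(
    MixedCoilSet(FourierRZCoil(), FourierRZCoil()),
    FourierCurrentPotentialField(),
    ToroidalMagneticField(),
)
transforms = tree_map_fields(precompute_leaf, field)
print(len(transforms), len(transforms[0]))  # 3 2
```

In JAX, `jax.tree_util.tree_map(fn, tree, is_leaf=...)` offers the same kind of traversal for objects registered as pytrees; whether the DESC field classes are registered so that this works directly is not assumed here.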

The ideal approach, which avoids all the if statements and the need for a utility function in the first place, would be to incorporate the B computation from `compute_magnetic_field` into the normal `compute` method and add the `MagneticField` classes as parameterizations in the data index. But this will also be a fair amount of work; I just wanted to lay out the ideas.
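The "ideal" approach amounts to putting the magnetic field into the same dependency graph as every other quantity, so the required transforms are discovered automatically instead of through `isinstance` checks. A toy version of that idea follows; the registry layout, its keys, the transform names, and the helpers `register`/`required_transforms` are all hypothetical simplifications, not DESC's actual data index.

```python
# Hypothetical, simplified data index: parameterization -> quantity -> metadata.
DATA_INDEX = {}


def register(parameterization, name, transforms=(), data=()):
    """Record which transforms and which other quantities a quantity depends on."""
    DATA_INDEX.setdefault(parameterization, {})[name] = {
        "transforms": set(transforms),
        "data": list(data),
    }


def required_transforms(parameterization, names):
    """Walk the dependency graph and collect every transform the requested names need."""
    needed, stack, seen = set(), list(names), set()
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        entry = DATA_INDEX[parameterization][name]
        needed |= entry["transforms"]
        stack.extend(entry["data"])
    return needed


# A coil registered as a parameterization, with B as a first-class quantity.
register("FourierRZCoil", "x", transforms=["R", "Z"])
register("FourierRZCoil", "x_s", transforms=["R_s", "Z_s"], data=["x"])
register("FourierRZCoil", "B", data=["x", "x_s"])  # B needs positions and tangents

print(sorted(required_transforms("FourierRZCoil", ["B"])))  # ['R', 'R_s', 'Z', 'Z_s']
```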

dpanici, Jul 02 '24 15:07