DESC
Coil-Coil & Plasma-Coil Minimum Distance Objectives
- Adds the objective `CoilsetMinDistance`, which returns the minimum distance to another coil for each coil in a coilset. Resolves #898
- Adds the objective `PlasmaCoilsetMinDistance`, which returns the minimum distance to the plasma surface for each coil in a coilset. Resolves #900
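A minimal sketch of the quantity `PlasmaCoilsetMinDistance` returns, assuming the coils and the plasma boundary have already been sampled into Cartesian point arrays. The function name and array layout are hypothetical, not DESC's API; it is a brute-force all-pairs reduction, which keeps one minimum per coil.

```python
import jax
import jax.numpy as jnp


def plasma_coil_min_distance(coil_pts, surf_pts):
    """Minimum distance from each coil to a sampled plasma surface (hypothetical helper).

    coil_pts : (ncoil, npts, 3) Cartesian points sampled along each coil
    surf_pts : (nsurf, 3) Cartesian points sampled on the plasma boundary
    returns  : (ncoil,) minimum coil-to-surface distance for each coil
    """
    # differences between every coil point and every surface point, shape (ncoil, npts, nsurf, 3)
    dx = coil_pts[:, :, None, :] - surf_pts[None, None, :, :]
    dist = jnp.linalg.norm(dx, axis=-1)
    # reduce over coil points and surface points, keeping one value per coil
    return jnp.min(dist, axis=(1, 2))


# pure array math, so it can be jit-compiled and differentiated
plasma_coil_min_distance_jit = jax.jit(plasma_coil_min_distance)
```

Because the result is one value per coil, a bound (e.g. a minimum allowed clearance) can be applied coil by coil rather than to a single global minimum.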
Dependencies:
- #1016
- #1017
- #1018
TODO:
- [x] unit tests for `CoilsetMinDistance`
- [x] unit tests for `PlasmaCoilsetMinDistance`
- [x] regression test using both objectives in a coil optimization
- [ ] make it work with surface and/or equilibrium to resolve #947
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 94.97%. Comparing base (80acaf9) to head (2df3a94). Report is 1822 commits behind head on master.
Additional details and impacted files
@@ Coverage Diff @@
## master #977 +/- ##
==========================================
+ Coverage 94.91% 94.97% +0.05%
==========================================
Files 87 87
Lines 21584 21725 +141
==========================================
+ Hits 20487 20633 +146
+ Misses 1097 1092 -5
| Files with missing lines | Coverage Δ | |
|---|---|---|
| desc/coils.py | 96.85% <100.00%> (+0.29%) | :arrow_up: |
| desc/geometry/surface.py | 96.80% <ø> (ø) | |
| desc/objectives/__init__.py | 100.00% <ø> (ø) | |
| desc/objectives/_coils.py | 99.12% <100.00%> (-0.06%) | :arrow_down: |
| desc/objectives/_geometry.py | 96.56% <100.00%> (+0.27%) | :arrow_up: |
| desc/optimize/optimizer.py | 97.05% <100.00%> (+0.01%) | :arrow_up: |
FYI Thomas is already working on this. Maybe we can combine his work with this.
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | +6.76 +/- 8.03 | +3.42e-02 +/- 4.06e-02 | 5.40e-01 +/- 3.9e-02 | 5.06e-01 +/- 9.5e-03 |
test_build_transform_fft_midres | +6.65 +/- 3.87 | +3.92e-02 +/- 2.28e-02 | 6.28e-01 +/- 1.1e-02 | 5.89e-01 +/- 2.0e-02 |
test_build_transform_fft_highres | +4.13 +/- 4.97 | +4.04e-02 +/- 4.86e-02 | 1.02e+00 +/- 2.8e-02 | 9.78e-01 +/- 4.0e-02 |
test_equilibrium_init_lowres | +1.73 +/- 2.91 | +6.34e-02 +/- 1.07e-01 | 3.72e+00 +/- 7.8e-02 | 3.66e+00 +/- 7.3e-02 |
test_equilibrium_init_medres | +1.14 +/- 2.45 | +4.71e-02 +/- 1.01e-01 | 4.17e+00 +/- 9.7e-02 | 4.12e+00 +/- 2.8e-02 |
test_equilibrium_init_highres | +1.68 +/- 5.86 | +9.70e-02 +/- 3.38e-01 | 5.86e+00 +/- 1.6e-01 | 5.76e+00 +/- 3.0e-01 |
test_objective_compile_dshape_current | -6.89 +/- 5.34 | -2.84e-01 +/- 2.20e-01 | 3.84e+00 +/- 7.3e-02 | 4.12e+00 +/- 2.1e-01 |
test_objective_compile_atf | -1.50 +/- 4.30 | -1.26e-01 +/- 3.60e-01 | 8.26e+00 +/- 1.8e-01 | 8.39e+00 +/- 3.1e-01 |
test_objective_compute_dshape_current | -0.27 +/- 4.44 | -3.37e-06 +/- 5.60e-05 | 1.26e-03 +/- 3.5e-05 | 1.26e-03 +/- 4.4e-05 |
test_objective_compute_atf | -1.98 +/- 5.32 | -8.50e-05 +/- 2.29e-04 | 4.21e-03 +/- 1.7e-04 | 4.29e-03 +/- 1.5e-04 |
test_objective_jac_dshape_current | +1.44 +/- 5.87 | +5.50e-04 +/- 2.24e-03 | 3.87e-02 +/- 1.7e-03 | 3.81e-02 +/- 1.5e-03 |
test_objective_jac_atf | +0.18 +/- 2.44 | +3.47e-03 +/- 4.60e-02 | 1.89e+00 +/- 3.4e-02 | 1.89e+00 +/- 3.1e-02 |
test_perturb_1 | -5.57 +/- 2.65 | -7.69e-01 +/- 3.66e-01 | 1.30e+01 +/- 3.2e-02 | 1.38e+01 +/- 3.6e-01 |
test_perturb_2 | -4.98 +/- 3.07 | -9.45e-01 +/- 5.84e-01 | 1.80e+01 +/- 1.7e-01 | 1.90e+01 +/- 5.6e-01 |
test_proximal_jac_atf | -0.74 +/- 1.31 | -5.46e-02 +/- 9.62e-02 | 7.31e+00 +/- 7.7e-02 | 7.36e+00 +/- 5.8e-02 |
test_proximal_freeb_compute | -1.15 +/- 0.83 | -2.06e-03 +/- 1.49e-03 | 1.77e-01 +/- 8.8e-04 | 1.79e-01 +/- 1.2e-03 |
test_proximal_freeb_jac | -0.58 +/- 2.38 | -4.31e-02 +/- 1.76e-01 | 7.38e+00 +/- 6.6e-02 | 7.42e+00 +/- 1.6e-01 |
test_solve_fixed_iter | -1.07 +/- 11.13 | -1.60e-01 +/- 1.66e+00 | 1.48e+01 +/- 8.9e-01 | 1.50e+01 +/- 1.4e+00 |
I think this is probably a better implementation:

    import numpy as np
    import jax
    import jax.numpy as jnp

    npts = 100
    ncoil = 10
    pts = jnp.array(np.random.random((ncoil, npts, 3)))  # or output of coilset.compute("x")

    def body(i):
        # compare every point on coil i with every point on every coil, shape (ncoil, npts, npts, 3)
        dx = pts[i][None, :, None, :] - pts[:, None, :, :]
        dist = jnp.linalg.norm(dx, axis=-1)
        # ignore distance to pts within each coil (boolean mask, broadcast over the point axes)
        mask = jnp.ones(ncoil, dtype=bool).at[i].set(False)[:, None, None]
        return jnp.min(dist, where=mask, initial=jnp.inf)

    min_dist_per_coil = jax.lax.fori_loop(
        0, ncoil, lambda i, mind: mind.at[i].set(body(i)), jnp.zeros(ncoil)
    )
    # probably faster if there aren't a ton of coils/pts:
    min_dist_per_coil = jax.vmap(body)(jnp.arange(ncoil))

This returns an array of length `ncoil` where `arr[i]` is the minimum distance from `coils[i]` to any other coil in the set.
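One way to sanity-check a per-coil minimum distance is to compare it against a brute-force loop. The sketch below, with hypothetical function names, uses a fully vectorized variant (no loop over coils) and a plain NumPy reference. The vectorized form needs memory proportional to ncoil² · npts², so looping over coils as in the `fori_loop`/`vmap` snippet above may be preferable for large coilsets.

```python
import numpy as np
import jax.numpy as jnp


def coil_coil_min_distance(pts):
    """Vectorized: pts (ncoil, npts, 3) -> (ncoil,) minimum distance to any other coil."""
    ncoil = pts.shape[0]
    # dist[i, j, a, b] = |pts[i, a] - pts[j, b]|, shape (ncoil, ncoil, npts, npts)
    dist = jnp.linalg.norm(pts[:, None, :, None, :] - pts[None, :, None, :, :], axis=-1)
    # exclude each coil's distance to itself by masking the diagonal of the coil axes
    not_self = ~jnp.eye(ncoil, dtype=bool)[:, :, None, None]
    # min requires an `initial` value when a `where` mask is used
    return jnp.min(dist, axis=(1, 2, 3), where=not_self, initial=jnp.inf)


def coil_coil_min_distance_reference(pts):
    """Brute-force NumPy loop, used only to check the vectorized version."""
    pts = np.asarray(pts)
    ncoil = pts.shape[0]
    out = np.full(ncoil, np.inf)
    for i in range(ncoil):
        for j in range(ncoil):
            if i == j:
                continue
            d = np.linalg.norm(pts[i][:, None, :] - pts[j][None, :, :], axis=-1)
            out[i] = min(out[i], d.min())
    return out


rng = np.random.default_rng(0)
pts = jnp.array(rng.random((4, 20, 3)))  # 4 coils, 20 points each
print(coil_coil_min_distance(pts))
print(coil_coil_min_distance_reference(pts))  # should agree to float32 precision
```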
TODO for me: make sure it calculates the "x" positions of EVERY coil (not just the unique ones) so that it can properly avoid self-intersection across the stellarator-symmetry reflection plane.
For dealing with symmetry: I think the easiest option is to have CoilSet.compute return data for all the coils (i.e. via duplication, rotation, flipping, etc.), including virtual coils, similar to how CoilSet.compute_magnetic_field accounts for virtual coils. Then the existing logic should work fine.
Just do this for "x" in the objective function
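A minimal sketch of the duplication/rotation/flipping described above, applied to "x" only. It assumes the coilset is given as Cartesian points of the unique coils plus a number of field periods `NFP` and a stellarator-symmetry flag `sym`. The helper name and the ordering of the virtual coils are assumptions and may not match how DESC's `CoilSet` orders them; for a minimum-distance objective only the set of points matters. Stellarator symmetry maps (R, φ, Z) to (R, −φ, −Z), which in Cartesian coordinates is (x, y, z) to (x, −y, −z).

```python
import jax.numpy as jnp


def expand_symmetric_points(pts, NFP, sym):
    """Hypothetical helper: points of all coils, including virtual ones.

    pts : (ncoil, npts, 3) Cartesian points of the unique coils
    NFP : number of field periods (copies rotated about z by 2*pi*k/NFP)
    sym : if True, also add stellarator-symmetric (mirrored) copies
    returns : (ncoil * NFP * (2 if sym else 1), npts, 3)
    """
    if sym:
        # (x, y, z) -> (x, -y, -z), i.e. (R, phi, Z) -> (R, -phi, -Z)
        flipped = pts * jnp.array([1.0, -1.0, -1.0])
        pts = jnp.concatenate([pts, flipped], axis=0)
    copies = []
    for k in range(NFP):  # NFP and sym must be static Python values
        phi = 2 * jnp.pi * k / NFP
        c, s = jnp.cos(phi), jnp.sin(phi)
        # rotation matrix about the z-axis by angle phi
        rot = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
        copies.append(pts @ rot.T)
    return jnp.concatenate(copies, axis=0)
```

The expanded array can be fed to the same per-coil minimum-distance logic, so a unique coil is also checked against its own mirror image across the symmetry plane.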
I would have preferred keeping this PR just coil-coil distance...
This PR is ready except it is waiting on #1016 to finish adding/editing tests.
The test I added is failing at an unexpected place (when eq_fixed is False)...
I think it is a JAX error complaining about an int, but JAX itself hits an error when it tries to raise that error.
The issue is logged with JAX (https://github.com/google/jax/issues/20397#issuecomment-2143671178), and a PR is open that should let us see the error more clearly (https://github.com/google/jax/pull/21567).
The bug on our part was calling `compute` instead of `_compute` inside `_coils.py`; changing that fixes the error (`compute` contains some JAX-unfriendly checks that `_compute` skips).
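The failure mode can be reproduced in a few lines. This is a toy sketch, not DESC code: `ToyCoil` and its methods are hypothetical. The error in this thread looked int-related; the toy uses a boolean check instead, which fails for the same underlying reason: a Python-level check needs a concrete value, but under `jit` the input is a tracer.

```python
import jax
import jax.numpy as jnp


class ToyCoil:
    """Hypothetical: public compute() with Python-level checks, private _compute() in pure JAX."""

    def compute(self, s):
        # Python-level validation: fine on concrete arrays, but under jit `s` is a
        # tracer, and converting the traced comparison to a Python bool raises an error
        if jnp.any((s < 0) | (s > 1)):
            raise ValueError("s must lie in [0, 1]")
        return self._compute(s)

    def _compute(self, s):
        # pure array math, safe to trace
        phi = 2 * jnp.pi * s
        return jnp.stack([jnp.cos(phi), jnp.sin(phi), jnp.zeros_like(phi)], axis=-1)


coil = ToyCoil()
s = jnp.linspace(0, 1, 8, endpoint=False)

x = jax.jit(coil._compute)(s)  # works: no Python control flow on traced values
try:
    jax.jit(coil.compute)(s)  # fails while tracing
    traced_ok = True
except TypeError:  # JAX concretization errors subclass TypeError
    traced_ok = False
```

Calling `coil.compute(s)` eagerly still works, which is why such checks can hide until the objective is traced.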