DESC
JIT Error encountered when optimizing `GammaC`
The error seems to occur when optimizing the `GammaC` objective on the `gh/Gamma_c` branch. It happens on the second optimization step and seems related to the JIT cache. The error only occurs when attempting an optimization at a resolution that has already been optimized at. Changing the equilibrium resolution between steps seems to avoid the issue, so I assume it is related to caching.
MWE:
from desc import set_device
set_device("gpu")
import jax
import numpy as np
import desc.examples
from desc.continuation import solve_continuation_automatic
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.grid import ConcentricGrid, LinearGrid
from desc.io import load
from desc.objectives import ( # FixIota,
AspectRatio,
Elongation,
FixBoundaryR,
FixBoundaryZ,
FixCurrent,
FixPressure,
FixPsi,
ForceBalance,
GammaC,
GenericObjective,
ObjectiveFunction,
QuasisymmetryTwoTerm,
)
from desc.optimize import Optimizer
from desc.plotting import plot_boozer_surface
import pdb
from desc.backend import jnp
from desc.examples import get
def run_opt_step(k, eq):
    """Run one step of the multigrid boundary-optimization example.

    Builds an objective from ``GammaC``, ``Elongation`` and a generic
    ``curvature_k2_rho`` objective, which are combined in a least-squares
    sense. It then fixes every boundary mode with ``max(|m|, |n|) > k``, as
    well as the R (0, 0, 0) mode. Finally, it runs the ``"proximal-lsq-exact"``
    optimizer subject to force balance and fixed pressure, current and
    toroidal flux.

    Parameters
    ----------
    k : int
        Largest boundary mode number allowed to vary in this step.
    eq : Equilibrium
        Starting equilibrium. It is left unmodified because the optimizer is
        called with ``copy=True``.

    Returns
    -------
    Equilibrium
        The optimized copy of ``eq`` (first entry of the returned things).
    """
    # Grid on the rho = 1 surface used by the shape objectives.
    shape_grid = LinearGrid(
        M=int(eq.M),
        N=int(eq.N),
        rho=np.array([1.0]),
        NFP=eq.NFP,
        sym=True,
        axis=False,
    )
    # Field-line label(s) and flux surfaces on which GammaC is evaluated.
    # NOTE(review): these are JAX arrays. The reported failure is an array
    # truth-value check inside JAX's cache lookup, so it may be worth
    # confirming whether NumPy arrays (np.array / np.linspace) behave the same.
    alpha = jnp.array([0.0])
    rho = jnp.linspace(0.85, 1.0, 2)

    objective = ObjectiveFunction(
        (
            GammaC(
                eq=eq,
                rho=rho,
                alpha=alpha,
                deriv_mode="fwd",
                batch=False,
                weight=1e3,
                Nemov=False,
            ),
            Elongation(eq=eq, grid=shape_grid, target=1),
            GenericObjective(
                f="curvature_k2_rho",
                thing=eq,
                grid=shape_grid,
                bounds=(-75, 15),
                weight=2e3,
            ),
        ),
    )

    # Fix every boundary mode with max(|m|, |n|) > k so that only the lower
    # modes vary in this step. The R (0, 0, 0) mode is fixed as well.
    R_modes = np.vstack(
        (
            [0, 0, 0],
            eq.surface.R_basis.modes[
                np.max(np.abs(eq.surface.R_basis.modes), 1) > k, :
            ],
        )
    )
    Z_modes = eq.surface.Z_basis.modes[
        np.max(np.abs(eq.surface.Z_basis.modes), 1) > k, :
    ]
    constraints = (
        ForceBalance(
            eq,
            grid=ConcentricGrid(
                L=round(2 * eq.L),
                M=round(1.5 * eq.M),
                N=round(1.5 * eq.N),
                NFP=eq.NFP,
                sym=eq.sym,
            ),
        ),
        FixBoundaryR(eq=eq, modes=R_modes),
        FixBoundaryZ(eq=eq, modes=Z_modes),
        FixPressure(eq=eq),
        FixCurrent(eq=eq),
        FixPsi(eq=eq),
    )

    # "proximal-lsq-exact" re-solves the equilibrium at each optimizer step.
    optimizer = Optimizer("proximal-lsq-exact")
    # NOTE(review): maxiter=3 is passed directly while options["maxiter"] is
    # 125. Confirm which one DESC honors for this method.
    eq_new, _ = optimizer.optimize(
        things=eq,
        objective=objective,
        constraints=constraints,
        maxiter=3,  # no need to solve to optimality at each multigrid step
        verbose=3,
        copy=True,  # don't modify the original; return an optimized copy
        options={
            # A smaller initial trust radius keeps the optimizer from taking an
            # overly large first step in a bad direction.
            "initial_trust_ratio": 1e-2,
            "maxiter": 125,
            "ftol": 1e-3,
            "xtol": 1e-8,
        },
    )
    return eq_new[0]
if __name__ == "__main__":
    eq = get("ESTELL")
    # Multigrid continuation: step k frees boundary modes with
    # max(|m|, |n|) <= k, for k = 1 .. eq.M.
    for k in range(1, int(eq.M) + 1):
        if not eq.is_nested():
            print("NOT NESTED")
            # An explicit raise (not an assert) so the check still runs under
            # `python -O`.
            raise RuntimeError(
                "Equilibrium is no longer nested; aborting optimization."
            )
        # Clear JAX's compilation caches before each step. This is an
        # attempted workaround for the JIT-cache error reported in this issue.
        jax.clear_caches()
        eq = run_opt_step(k, eq)
Error:
ValueError Traceback (most recent call last)
Cell In[1], line 137
135 break
136 jax.clear_caches()
--> 137 eq = run_opt_step(k, eq)
Cell In[1], line 107, in run_opt_step(k, eq)
103 optimizer = Optimizer("proximal-lsq-exact")
105 print("spot 1:", type(eq))
--> 107 eq_new, result = optimizer.optimize(
108 things = eq,
109 objective=objective,
110 constraints=constraints,
111 maxiter=3, # we don't need to solve to optimality at each multigrid step
112 verbose=3,
113 copy=True, # don't modify original, return a new optimized copy
114 options={
115 # Sometimes the default initial trust radius is too big, allowing the
116 # optimizer to take too large a step in a bad direction. If this happens,
117 # we can manually specify a smaller starting radius. Each optimizer has a
118 # number of different options that can be used to tune the performance.
119 # See the documentation for more info.
120 "initial_trust_ratio": 1e-2,
121 "maxiter": 125,
122 "ftol": 1e-3,
123 "xtol": 1e-8,
124 },
125 )
126 eq_new = eq_new[0]
128 return eq_new
File ~/DESC/desc/optimize/optimizer.py:311, in Optimizer.optimize(self, things, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options, copy)
307 print("Using method: " + str(self.method))
309 timer.start("Solution time")
--> 311 result = optimizers[method]["fun"](
312 objective,
313 nonlinear_constraint,
314 x0,
315 method,
316 x_scale,
317 verbose,
318 stoptol,
319 options,
320 )
322 if isinstance(objective, LinearConstraintProjection):
323 # remove wrapper to get at underlying objective
324 result["allx"] = [objective.recover(x) for x in result["allx"]]
File ~/DESC/desc/optimize/_desc_wrappers.py:270, in _optimize_desc_least_squares(objective, constraint, x0, method, x_scale, verbose, stoptol, options)
267 options.setdefault("initial_trust_ratio", 0.1)
268 options["max_nfev"] = stoptol["max_nfev"]
--> 270 result = lsqtr(
271 objective.compute_scaled_error,
272 x0=x0,
273 jac=objective.jac_scaled_error,
274 args=(objective.constants,),
275 x_scale=x_scale,
276 ftol=stoptol["ftol"],
277 xtol=stoptol["xtol"],
278 gtol=stoptol["gtol"],
279 maxiter=stoptol["maxiter"],
280 verbose=verbose,
281 callback=None,
282 options=options,
283 )
284 return result
File ~/DESC/desc/optimize/least_squares.py:176, in lsqtr(fun, x0, jac, bounds, args, x_scale, ftol, xtol, gtol, verbose, maxiter, callback, options)
173 assert in_bounds(x, lb, ub), "x0 is infeasible"
174 x = make_strictly_feasible(x, lb, ub)
--> 176 f = fun(x, *args)
177 nfev += 1
178 cost = 0.5 * jnp.dot(f, f)
File ~/DESC/desc/optimize/_constraint_wrappers.py:224, in LinearConstraintProjection.compute_scaled_error(self, x_reduced, constants)
208 """Compute the objective function and apply weighting / bounds.
209
210 Parameters
(...)
221
222 """
223 x = self.recover(x_reduced)
--> 224 f = self._objective.compute_scaled_error(x, constants)
225 return f
File ~/DESC/desc/optimize/_constraint_wrappers.py:843, in ProximalProjection.compute_scaled_error(self, x, constants)
841 constants = setdefault(constants, self.constants)
842 xopt, _ = self._update_equilibrium(x, store=False)
--> 843 return self._objective.compute_scaled_error(xopt, constants[0])
[... skipping hidden 6 frame]
File ~/.conda/envs/desc-env-latest/lib/python3.11/site-packages/jax/_src/pjit.py:1339, in seen_attrs_get(fun, in_type)
1337 cache = _seen_attrs.setdefault(fun.f, defaultdict(list))
1338 assert fun.in_type is None or fun.in_type == in_type
-> 1339 return cache[(fun.transforms, fun.params, in_type)]
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()