
JIT Error encountered when optimizing `GammaC`

Open dpanici opened this issue 4 months ago • 7 comments

The error seems to occur when optimizing the GammaC objective on the gh/Gamma_c branch. It happens on the second optimization step and appears to be related to the JIT cache: the error only occurs when attempting an optimization at a resolution that has previously been optimized at, and changing the equilibrium resolution between steps seems to avoid it, so I assume it is caching-related.
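For reference, the resolution-change workaround mentioned above looks roughly like the sketch below. It is illustrative only: run_with_resolution_bump is a hypothetical helper, the choice of which resolution to bump is an assumption, and it relies on Equilibrium.change_resolution accepting grid resolution keywords as in current DESC.

# Illustrative sketch of the resolution-change workaround (not part of the MWE).
# NOTE: run_with_resolution_bump is a hypothetical helper; the specific grid
# resolutions are arbitrary and only chosen to force a different compiled signature.
import jax


def run_with_resolution_bump(eq, k, run_opt_step):
    """Bump the equilibrium grid resolution before each step so no two steps
    optimize at exactly the same resolution (which is what triggers the error)."""
    # change_resolution is assumed to accept M_grid/N_grid keywords as in current DESC
    eq.change_resolution(M_grid=eq.M_grid + 1, N_grid=eq.N_grid + 1)
    jax.clear_caches()  # clearing caches alone was not enough in the MWE below
    return run_opt_step(k, eq)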

MWE:

from desc import set_device

set_device("gpu")

import jax
import numpy as np

import desc.examples
from desc.continuation import solve_continuation_automatic
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.grid import ConcentricGrid, LinearGrid
from desc.io import load
from desc.objectives import (  # FixIota,
    AspectRatio,
    Elongation,
    FixBoundaryR,
    FixBoundaryZ,
    FixCurrent,
    FixPressure,
    FixPsi,
    ForceBalance,
    GammaC,
    GenericObjective,
    ObjectiveFunction,
    QuasisymmetryTwoTerm,
)
from desc.optimize import Optimizer
from desc.plotting import plot_boozer_surface
import pdb
from desc.backend import jnp
from desc.examples import get


def run_opt_step(k, eq):
    """Run a step of the optimization example."""
    # this step will only optimize boundary modes with |m|,|n| <= k
    # we create an ObjectiveFunction, in this case made up of multiple objectives
    # which will be combined in a least squares sense

    shape_grid = LinearGrid(
        M=int(eq.M), N=int(eq.N), rho=np.array([1.0]), NFP=eq.NFP, sym=True, axis=False
    )

    ntransits = 8

    zeta_field_line = np.linspace(0, 2 * np.pi * ntransits, 64 * ntransits)
    alpha = jnp.array([0.0])
    rho = jnp.linspace(0.85, 1.0, 2)
    # rho = np.linspace(0.85, 1.0, 2)
    flux_surface_grid = LinearGrid(
        rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP
    )

    objective = ObjectiveFunction(
        (
            GammaC(
                eq=eq,
                rho=rho,
                alpha=alpha,
                deriv_mode="fwd",
                batch=False,
                weight=1e3,
                Nemov=False,
            ),
            Elongation(eq=eq, grid=shape_grid, target=1),  # bounds=(0.5, 2.0), weight=1e3
            GenericObjective(
                f="curvature_k2_rho",
                thing=eq,
                grid=shape_grid,
                bounds=(-75, 15),
                weight=2e3,
            ),
        ),
    )
    R_modes = np.vstack(
        (
            [0, 0, 0],
            eq.surface.R_basis.modes[
                np.max(np.abs(eq.surface.R_basis.modes), 1) > k, :
            ],
        )
    )
    Z_modes = eq.surface.Z_basis.modes[
        np.max(np.abs(eq.surface.Z_basis.modes), 1) > k, :
    ]
    constraints = (
        ForceBalance(
            eq,
            grid=ConcentricGrid(
                L=round(2 * eq.L),
                M=round(1.5 * eq.M),
                N=round(1.5 * eq.N),
                NFP=eq.NFP,
                sym=eq.sym,
            ),
        ),
        FixBoundaryR(eq=eq, modes=R_modes),
        FixBoundaryZ(eq=eq, modes=Z_modes),
        FixPressure(eq=eq),
        FixCurrent(eq=eq),
        FixPsi(eq=eq),
    )
    # this is the default optimizer, which re-solves the equilibrium at each step
    optimizer = Optimizer("proximal-lsq-exact")
    eq_new, result = optimizer.optimize(
        things=eq,
        objective=objective,
        constraints=constraints,
        maxiter=3,  # we don't need to solve to optimality at each multigrid step
        verbose=3,
        copy=True,  # don't modify original, return a new optimized copy
        options={
            # Sometimes the default initial trust radius is too big, allowing the
            # optimizer to take too large a step in a bad direction. If this happens,
            # we can manually specify a smaller starting radius. Each optimizer has a
            # number of different options that can be used to tune the performance.
            # See the documentation for more info.
            "initial_trust_ratio": 1e-2,
            "maxiter": 125,
            "ftol": 1e-3,
            "xtol": 1e-8,
        },
    )
    eq_new = eq_new[0]
   
    return eq_new

eq = get("ESTELL")
for k in np.arange(1, eq.M + 1, 1):
    if not eq.is_nested():
        print("NOT NESTED")
        assert eq.is_nested()
        break
    jax.clear_caches()
    eq = run_opt_step(k, eq)

Error:

ValueError                                Traceback (most recent call last)
Cell In[1], line 137
    135     break
    136 jax.clear_caches()
--> 137 eq = run_opt_step(k, eq)

Cell In[1], line 107, in run_opt_step(k, eq)
    103 optimizer = Optimizer("proximal-lsq-exact")
    105 print("spot 1:", type(eq))
--> 107 eq_new, result = optimizer.optimize(
    108     things = eq,
    109     objective=objective,
    110     constraints=constraints,
    111     maxiter=3,  # we don't need to solve to optimality at each multigrid step
    112     verbose=3,
    113     copy=True,  # don't modify original, return a new optimized copy
    114     options={
    115         # Sometimes the default initial trust radius is too big, allowing the
    116         # optimizer to take too large a step in a bad direction. If this happens,
    117         # we can manually specify a smaller starting radius. Each optimizer has a
    118         # number of different options that can be used to tune the performance.
    119         # See the documentation for more info.
    120         "initial_trust_ratio": 1e-2,
    121         "maxiter": 125,
    122         "ftol": 1e-3,
    123         "xtol": 1e-8,
    124     },
    125 )
    126 eq_new = eq_new[0]
    128 return eq_new

File ~/DESC/desc/optimize/optimizer.py:311, in Optimizer.optimize(self, things, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options, copy)
    307     print("Using method: " + str(self.method))
    309 timer.start("Solution time")
--> 311 result = optimizers[method]["fun"](
    312     objective,
    313     nonlinear_constraint,
    314     x0,
    315     method,
    316     x_scale,
    317     verbose,
    318     stoptol,
    319     options,
    320 )
    322 if isinstance(objective, LinearConstraintProjection):
    323     # remove wrapper to get at underlying objective
    324     result["allx"] = [objective.recover(x) for x in result["allx"]]

File ~/DESC/desc/optimize/_desc_wrappers.py:270, in _optimize_desc_least_squares(objective, constraint, x0, method, x_scale, verbose, stoptol, options)
    267     options.setdefault("initial_trust_ratio", 0.1)
    268 options["max_nfev"] = stoptol["max_nfev"]
--> 270 result = lsqtr(
    271     objective.compute_scaled_error,
    272     x0=x0,
    273     jac=objective.jac_scaled_error,
    274     args=(objective.constants,),
    275     x_scale=x_scale,
    276     ftol=stoptol["ftol"],
    277     xtol=stoptol["xtol"],
    278     gtol=stoptol["gtol"],
    279     maxiter=stoptol["maxiter"],
    280     verbose=verbose,
    281     callback=None,
    282     options=options,
    283 )
    284 return result

File ~/DESC/desc/optimize/least_squares.py:176, in lsqtr(fun, x0, jac, bounds, args, x_scale, ftol, xtol, gtol, verbose, maxiter, callback, options)
    173 assert in_bounds(x, lb, ub), "x0 is infeasible"
    174 x = make_strictly_feasible(x, lb, ub)
--> 176 f = fun(x, *args)
    177 nfev += 1
    178 cost = 0.5 * jnp.dot(f, f)

File ~/DESC/desc/optimize/_constraint_wrappers.py:224, in LinearConstraintProjection.compute_scaled_error(self, x_reduced, constants)
    208 """Compute the objective function and apply weighting / bounds.
    209 
    210 Parameters
   (...)
    221 
    222 """
    223 x = self.recover(x_reduced)
--> 224 f = self._objective.compute_scaled_error(x, constants)
    225 return f

File ~/DESC/desc/optimize/_constraint_wrappers.py:843, in ProximalProjection.compute_scaled_error(self, x, constants)
    841 constants = setdefault(constants, self.constants)
    842 xopt, _ = self._update_equilibrium(x, store=False)
--> 843 return self._objective.compute_scaled_error(xopt, constants[0])

    [... skipping hidden 6 frame]

File ~/.conda/envs/desc-env-latest/lib/python3.11/site-packages/jax/_src/pjit.py:1339, in seen_attrs_get(fun, in_type)
   1337 cache = _seen_attrs.setdefault(fun.f, defaultdict(list))
   1338 assert fun.in_type is None or fun.in_type == in_type
-> 1339 return cache[(fun.transforms, fun.params, in_type)]

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

dpanici · Oct 02 '24 20:10