Newton versus Dogleg/LM with partial jit

Open djbower opened this issue 10 months ago • 4 comments

I'm using JAX 0.5.2, equinox 0.11.12, and optax 0.2.4.

I am trying to pass optimistix solvers into a partially jitted function. Newton works perfectly, but Dogleg and LevenbergMarquardt (LM) both throw the same error. I would have expected to be able to hot-swap solvers. I can construct an MWE to probe this further, but I was hoping someone might immediately identify the issue:

File ~/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/core.py:1550, in get_aval(x)
   1548 if hasattr(x, '__jax_array__'):
   1549   return get_aval(x.__jax_array__())
-> 1550 raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")

TypeError: Argument '{ lambda a:f64[15,15]; b:f64[15]. let
    c:f64[15] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float64
    ] a b
  in (c,) }' of type '<class 'jax._src.core.Jaxpr'>' is not a valid JAX type

The function that gets my wrapped partially jitted solver:

    def get_wrapped_jit_solver(self) -> Callable:
        """Gets the jit solver with fixed and solver parameters set.

        Returns:
            jit solver with fixed and solver parameters set
        """
        fixed_parameters: FixedParameters = self.get_fixed_parameters(
            self.solution_args.fugacity_constraints, self.solution_args.mass_constraints
        )
        solver_parameters: SolverParameters = self.solution_args.solver_parameters

        # Newton works!
        solver = optx.Newton(1e-3, 1e-3, optx.rms_norm)

        # LevenbergMarquardt/Dogleg doesn't
        # solver = optx.Dogleg(1e-3, 1e-3, optx.rms_norm)

        def wrapped_jit_solver(
            solution: Solution,
            traced_parameters: TracedParameters,
        ) -> optx.Solution:
            # Generate a fresh key every time
            rng_key = random.key(random.randint(random.PRNGKey(0), (1,), 0, 2**32)[0])
            return solve(
                solution, traced_parameters, rng_key, fixed_parameters, solver_parameters, solver
            )

        return wrapped_jit_solver

where the top of solve looks like:

@partial(jit, static_argnames=["fixed_parameters", "solver_parameters", "solver"])
def solve(
    solution: Solution,
    traced_parameters: TracedParameters,
    rng_key: Array,
    fixed_parameters: FixedParameters,
    solver_parameters: SolverParameters,
    solver,
) -> optx.Solution:
    """Solves the system of non-linear equations

    Args:
        solution: Solution
        traced_parameters: Traced parameters
        rng_key: Random number generating key
        fixed_parameters: Fixed parameters
        solver_parameters: Solver parameters
        solver: Optimistix solver

    Returns:
        The solution
    """
    options: dict[str, Any] = {
        "lower": np.asarray(solver_parameters.lower),
        "upper": np.asarray(solver_parameters.upper),
        "jac": solver_parameters.jac,
    }

    # First solver attempt (without perturbation)
    sol = optx.root_find(
        objective_function,
        solver,
        solution.data,
        args={"traced_parameters": traced_parameters, "fixed_parameters": fixed_parameters},
        throw=solver_parameters.throw,
        max_steps=solver_parameters.max_steps,
        options=options,
    )

djbower avatar Mar 08 '25 14:03 djbower

It looks like some part of your code is trying to jit some of its 'static' (non-array) arguments.

Hard to say exactly which bit without a MWE, but at the very least you should be able to sidestep the need for that by simply using equinox.filter_jit everywhere -- essentially as a more-ergonomic jax.jit that avoids the need to manually determine what is an array and what isn't.

patrick-kidger avatar Mar 08 '25 16:03 patrick-kidger

Thanks @patrick-kidger. So the origin of this problem is that I was trying to return an Optimistix Solution instance from a partially jitted function. If I return only the arrays that I need (like sol.value and sol.result._value), everything works as intended. I had assumed it would be OK to return a Solution instance, but maybe not?

djbower avatar Mar 12 '25 07:03 djbower

I suspect this is because the Gauss-Newton solver state contains a jaxpr, and that state is part of the returned solution.

johannahaffner avatar Mar 12 '25 08:03 johannahaffner

Ah yup, that'll be it -- for exactly the reason Johanna gave. jax.jit (but not eqx.filter_jit) will try to cast all its outputs to arrays.

patrick-kidger avatar Mar 12 '25 12:03 patrick-kidger