Newton versus Dogleg/LM with partial jit
I'm using JAX 0.5.2, equinox 0.11.12 and optax 0.2.4.
I am trying to pass Optimistix solvers into a partially jitted function. Newton works perfectly, but Dogleg and LevenbergMarquardt both throw the same error. I would have expected to be able to hot-swap solvers. I can construct an MWE to probe this further, but I was hoping someone might identify the issue immediately:
```
File ~/Documents/academic/explanetology/atmodeller/.venv/lib/python3.10/site-packages/jax/_src/core.py:1550, in get_aval(x)
   1548 if hasattr(x, '__jax_array__'):
   1549   return get_aval(x.__jax_array__())
-> 1550 raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")

TypeError: Argument '{ lambda a:f64[15,15]; b:f64[15]. let
    c:f64[15] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float64
    ] a b
  in (c,) }' of type '<class 'jax._src.core.Jaxpr'>' is not a valid JAX type
```
Here is the method that returns my wrapped, partially jitted solver:

```python
from typing import Callable

import optimistix as optx
from jax import random

# FixedParameters, SolverParameters, Solution, TracedParameters and solve
# are defined elsewhere in my code.


def get_wrapped_jit_solver(self) -> Callable:
    """Gets the jit solver with fixed and solver parameters set.

    Returns:
        jit solver with fixed and solver parameters set
    """
    fixed_parameters: FixedParameters = self.get_fixed_parameters(
        self.solution_args.fugacity_constraints, self.solution_args.mass_constraints
    )
    solver_parameters: SolverParameters = self.solution_args.solver_parameters

    # Newton works!
    solver = optx.Newton(1e-3, 1e-3, optx.rms_norm)
    # LevenbergMarquardt/Dogleg doesn't
    # solver = optx.Dogleg(1e-3, 1e-3, optx.rms_norm)

    def wrapped_jit_solver(
        solution: Solution,
        traced_parameters: TracedParameters,
    ) -> optx.Solution:
        # Derive a key (note: seeded from the fixed PRNGKey(0), so it is the
        # same key on every call rather than a fresh one)
        rng_key = random.key(random.randint(random.PRNGKey(0), (1,), 0, 2**32)[0])
        return solve(
            solution, traced_parameters, rng_key, fixed_parameters, solver_parameters, solver
        )

    return wrapped_jit_solver
```
where the top of solve looks like:
```python
from functools import partial
from typing import Any

import numpy as np
import optimistix as optx
from jax import Array, jit

# Solution, TracedParameters, FixedParameters, SolverParameters and
# objective_function are defined elsewhere in my code.


@partial(jit, static_argnames=["fixed_parameters", "solver_parameters", "solver"])
def solve(
    solution: Solution,
    traced_parameters: TracedParameters,
    rng_key: Array,
    fixed_parameters: FixedParameters,
    solver_parameters: SolverParameters,
    solver,
) -> optx.Solution:
    """Solves the system of non-linear equations

    Args:
        solution: Solution
        traced_parameters: Traced parameters
        rng_key: Random number generating key
        fixed_parameters: Fixed parameters
        solver_parameters: Solver parameters
        solver: Optimistix solver (passed as static, so it must be hashable)

    Returns:
        The solution
    """
    options: dict[str, Any] = {
        "lower": np.asarray(solver_parameters.lower),
        "upper": np.asarray(solver_parameters.upper),
        "jac": solver_parameters.jac,
    }

    # First solver attempt (without perturbation)
    sol = optx.root_find(
        objective_function,
        solver,
        solution.data,
        args={"traced_parameters": traced_parameters, "fixed_parameters": fixed_parameters},
        throw=solver_parameters.throw,
        max_steps=solver_parameters.max_steps,
        options=options,
    )
    # ... (rest of solve omitted)
```
It looks like some part of your code is trying to jit some of its 'static' (non-array) arguments.
Hard to say exactly which part without an MWE, but at the very least you should be able to sidestep the problem by using equinox.filter_jit everywhere: it is essentially a more ergonomic jax.jit that works out for you which arguments are arrays and which aren't, so you don't have to mark them as static by hand.
Thanks @patrick-kidger. So the origin of this problem is that I was trying to return an Optimistix Solution instance from a partially jitted function. If I return just the arrays I need (such as sol.value and sol.result._value), everything works as intended. I had assumed it would be OK to return a Solution instance, but maybe not?
I suspect this is because the Gauss-Newton solver state (which Dogleg and LevenbergMarquardt build on) contains a jaxpr, and that state is part of the returned Solution (in sol.state).
Ah yup, that'll be it, for exactly the reason Johanna gave: jax.jit (but not eqx.filter_jit) will try to cast all of its outputs to arrays.
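The failure mode can be reproduced in plain JAX, without Optimistix. This is a minimal sketch under one stated assumption: a hand-built jaxpr for a matrix-vector product stands in for the jaxpr held in the solver state; it is not the object Optimistix actually stores. The name returns_jaxpr is hypothetical.

```python
import jax
import jax.numpy as jnp

# A real (un-closed) jaxpr for a matrix-vector product, standing in for the
# jaxpr carried by the Dogleg/LM solver state.
matvec_jaxpr = jax.make_jaxpr(jnp.dot)(jnp.ones((3, 3)), jnp.ones(3)).jaxpr


@jax.jit
def returns_jaxpr(x):
    # The jaxpr is a non-array pytree leaf, so jax.jit cannot turn it into
    # an output and raises a TypeError while tracing.
    return 2.0 * x, matvec_jaxpr


try:
    returns_jaxpr(jnp.ones(3))
    raised = False
except TypeError:
    raised = True
```

This mirrors the traceback in the question, where get_aval rejects a jax._src.core.Jaxpr found among the outputs.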