Interactively step through solve with jax.lax.scan
Hi,
I am trying to use jax.lax.scan together with the interactive solving approach (example in the documentation) for a minimization. However, I found that for some solvers (least_squares), using lax.scan results in a type error. (Everything works with a standard for-loop.)
This is the example I'm working with:
import jax.numpy as jnp
import optimistix
from jax.tree_util import Partial
import jax
import numpy as np
### work with lax.scan
# solver = optimistix.BFGS(rtol=1e-3, atol=1e-3)
# solver = optimistix.NonlinearCG(rtol=1e-3, atol=1e-3)
### do NOT work with lax.scan
# solver = optimistix.GaussNewton(rtol=1e-3, atol=1e-3)
# solver = optimistix.LevenbergMarquardt(rtol=1e-3, atol=1e-3)
# solver = optimistix.IndirectLevenbergMarquardt(rtol=1e-3, atol=1e-3)
# solver = optimistix.NelderMead(rtol=1e-3, atol=1e-3)
# solver = optimistix.Dogleg(rtol=1e-3, atol=1e-3)
def test_func(x, *args):
    return jnp.sum(x**2) + 1.0, None
fn = test_func
y = jnp.array(np.random.uniform(-1,1,size=(5,5)))
args = None
options = dict(lower=-1.0, upper=1.0)
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()
state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
step = Partial(solver.step, fn=fn, args=args, options=options, tags=tags)
def step_helper(carry, xs):
    y, state, _ = carry
    return step(y=y, state=state), None
carry = (y, state, None)
carry, _ = jax.lax.scan(step_helper, carry, length=10)
# for _ in range(10):
#     carry, _ = step_helper(carry, None)
y, state, aux = carry
The error which is raised is:
TypeError: Value { lambda a:f32[5,5]; b:f32[5,5]. let c:f32[5,5] = mul b a d:f32[] = reduce_sum[axes=(0, 1)] c in (d,) } with type <class 'jax._src.core.Jaxpr'> is not a valid JAX type
Am I doing something wrong or is there something else going on?
Hi Matilda,
Some solvers, such as the GaussNewton family of least-squares solvers, have a jaxpr as part of their state (a jaxpr is JAX's intermediate representation of a traced program). In the least-squares solvers, the residual Jacobian contains a jaxpr. A jaxpr is not an array, so it is not a valid JAX type and cannot be part of the carry of jax.lax.scan, which is what the TypeError is telling you.
You can partition the state such that only the dynamic elements (the arrays) are part of the carry, and the static stuff (which does not change from step to step) is closed over by your step_helper function.
The following code works:
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optimistix
from jax.tree_util import Partial
solver = optimistix.GaussNewton(rtol=1e-3, atol=1e-3)
def test_func(x, *args):
    return jnp.sum(x**2) + 1.0, None
fn = test_func
y = jnp.array(np.random.uniform(-1,1,size=(5,5)))
args = None
options = dict(lower=-1.0, upper=1.0)
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()
state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
dynamic, static = eqx.partition(state, eqx.is_array)
step = Partial(solver.step, fn=fn, args=args, options=options, tags=tags)
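def step_helper(carry, xs):
    y, dynamic, _ = carry
    # Rebuild the full state from the carried arrays and the closed-over static part.
    state = eqx.combine(dynamic, static)
    y, state, aux = step(y=y, state=state)
    # Only the array leaves go back into the carry.
    dynamic, _ = eqx.partition(state, eqx.is_array)
    carry = (y, dynamic, aux)
    return carry, None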
carry = (y, dynamic, None)
carry, _ = jax.lax.scan(step_helper, carry, length=10)
y, dynamic, aux = carry
state = eqx.combine(dynamic, static)  # the carry only holds the array part
NelderMead worked for me using your MWE code.
Hey,
thanks for the help. Everything works now.
NelderMead also originally worked for me. I labelled it as "not working" by mistake.