optimistix

Interactively step through solve with jax.lax.scan

matillda123 opened this issue 1 year ago (status: open, 2 comments)

Hi,

I am trying to use jax.lax.scan together with the interactive solving approach (as in the example in the documentation) to run a minimization. However, I found that for some solvers (the least-squares ones) using lax.scan raises a TypeError. (Everything works with a standard for-loop.)

This is the example I'm working with:

```python
import jax.numpy as jnp
import optimistix
from jax.tree_util import Partial
import jax
import numpy as np

# Uncomment exactly one of the solvers below.

### work with lax.scan
# solver = optimistix.BFGS(rtol=1e-3, atol=1e-3)
# solver = optimistix.NonlinearCG(rtol=1e-3, atol=1e-3)

### do NOT work with lax.scan
# solver = optimistix.GaussNewton(rtol=1e-3, atol=1e-3)
# solver = optimistix.LevenbergMarquardt(rtol=1e-3, atol=1e-3)
# solver = optimistix.IndirectLevenbergMarquardt(rtol=1e-3, atol=1e-3)
# solver = optimistix.NelderMead(rtol=1e-3, atol=1e-3)
# solver = optimistix.Dogleg(rtol=1e-3, atol=1e-3)


def test_func(x, *args):
    return jnp.sum(x**2) + 1.0, None


fn = test_func
y = jnp.array(np.random.uniform(-1, 1, size=(5, 5)))

args = None
options = dict(lower=-1.0, upper=1.0)
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
step = Partial(solver.step, fn=fn, args=args, options=options, tags=tags)

def step_helper(carry, xs):
    y, state, _ = carry
    # solver.step returns (y, state, aux), which becomes the new carry
    return step(y=y, state=state), None

carry = (y, state, None)

carry, _ = jax.lax.scan(step_helper, carry, length=10)

# Equivalent plain Python loop (works for all solvers):
# for _ in range(10):
#     carry, _ = step_helper(carry, None)

y, state, aux = carry
```

The error raised is:

```
TypeError: Value { lambda a:f32[5,5]; b:f32[5,5]. let
    c:f32[5,5] = mul b a
    d:f32[] = reduce_sum[axes=(0, 1)] c
  in (d,) } with type <class 'jax._src.core.Jaxpr'> is not a valid JAX type
```

Am I doing something wrong or is there something else going on?

matillda123 · Dec 31 '24 11:12

Hi Matilda,

Some solvers, such as the GaussNewton family of least-squares solvers, have a jaxpr as part of their state (a jaxpr is the intermediate representation in which traced JAX programs are expressed). In the least-squares solvers, the residual Jacobian contains a jaxpr. A jaxpr is not an array and not a valid JAX type, so it cannot be part of the jax.lax.scan carry; that is exactly what the TypeError is complaining about.
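The incompatibility can be reproduced without optimistix. The sketch below is plain JAX; the traced function is a toy chosen only to produce a jaxpr similar to the one in the error message. Putting that jaxpr object into a jax.lax.scan carry should make JAX reject it with a TypeError much like the one reported above, since only array-like leaves can be carried.

```python
import jax
import jax.numpy as jnp

# Trace a small toy function to obtain a Jaxpr object, similar to the one in the error.
closed = jax.make_jaxpr(lambda a, b: jnp.sum(a * b))(jnp.ones((5, 5)), jnp.ones((5, 5)))
jaxpr = closed.jaxpr  # a pytree leaf, but not an array

def body(carry, _):
    x, j = carry
    return (2.0 * x, j), None  # the jaxpr is passed through unchanged

try:
    # scan has to turn every carry leaf into an abstract array value; the jaxpr fails this.
    jax.lax.scan(body, (jnp.array(1.0), jaxpr), length=3)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", str(err)[:80])
```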

You can partition the state so that only the dynamic (array) elements are part of the carry, while the static parts (which do not change from step to step) are closed over by your step_helper function.

The following code works:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optimistix
from jax.tree_util import Partial


solver = optimistix.GaussNewton(rtol=1e-3, atol=1e-3)

def test_func(x, *args):
    return jnp.sum(x**2) + 1.0, None

fn = test_func
y = jnp.array(np.random.uniform(-1, 1, size=(5, 5)))

args = None
options = dict(lower=-1.0, upper=1.0)
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
# Array leaves go into the scan carry; everything else (including the jaxpr)
# is closed over by step_helper and assumed not to change between steps.
dynamic, static = eqx.partition(state, eqx.is_array)
step = Partial(solver.step, fn=fn, args=args, options=options, tags=tags)

def step_helper(carry, xs):
    y, dynamic, _ = carry
    state = eqx.combine(dynamic, static)  # rebuild the full solver state
    y, state, aux = step(y=y, state=state)
    dynamic, _ = eqx.partition(state, eqx.is_array)  # keep only the arrays

    carry = (y, dynamic, aux)
    return carry, None

carry = (y, dynamic, None)
carry, _ = jax.lax.scan(step_helper, carry, length=10)
y, dynamic, aux = carry
state = eqx.combine(dynamic, static)  # reattach the static parts after the loop
```
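To check whether another solver's state needs this treatment, one can list the leaves that are not array-like before calling jax.lax.scan. The helper non_carryable_leaves below is hypothetical (not part of JAX, Equinox or Optimistix), and its notion of "carryable" (arrays plus Python/NumPy scalars) only approximates what JAX accepts. It is shown on a toy state; applied to the output of solver.init(...) it should similarly point at any jaxpr in a solver's state, though this has not been checked for every solver.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Leaf types that jax.lax.scan can carry (approximation): arrays, including
# tracers, and Python/NumPy scalars, which JAX converts to arrays.
_CARRYABLE = (jax.Array, np.ndarray, np.generic, bool, int, float, complex)

def non_carryable_leaves(tree):
    """Hypothetical helper: (leaf index, type name) for each non-array-like leaf."""
    leaves = jax.tree_util.tree_leaves(tree)
    return [(i, type(leaf).__name__) for i, leaf in enumerate(leaves)
            if not isinstance(leaf, _CARRYABLE)]

# Hypothetical toy state: an iterate, a step counter and a traced jaxpr.
jaxpr = jax.make_jaxpr(lambda a: jnp.sum(a ** 2))(jnp.ones(3)).jaxpr
toy_state = {"y": jnp.zeros(3), "count": 0, "jac": jaxpr}

print(non_carryable_leaves(toy_state))
```

Dict keys are flattened in sorted order ("count", "jac", "y"), so only the jaxpr at leaf index 1 should be reported; an empty list would suggest the state can go into the carry as-is.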

NelderMead worked for me using your MWE code.

johannahaffner · Jan 01 '25 14:01

Hey,

Thanks for the help, everything works now. NelderMead also worked for me originally; I labelled it as "not working" by mistake.

matillda123 · Jan 02 '25 13:01