jaxopt
LevenbergMarquardt does not seem to work with non-flat input.
GaussNewton works as intended with pytree parameters, and I would expect the same from LM. Instead, I had to flatten the parameters into a single array to make it work.
The error is raised at line 445, because the pytree structures of `params` and `vec` do not match.
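To make "the pytree structures do not match" concrete, here is a minimal illustration, not jaxopt's actual code. It assumes, for illustration only, that `vec` is a step stored as one flat array while `params` is a dict; `jax.tree_util.tree_map` cannot combine the two until the flat vector is mapped back to the dict layout with `jax.flatten_util.ravel_pytree`. The helpers `structures_match` and `try_update` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Dict-valued parameters (a pytree with two leaves).
params = {'A': jnp.arange(5) * 1.0, 'B': jnp.arange(5) ** 2 * 1.0}

# Hypothetical step stored as a single flat array, one entry per scalar parameter.
flat_params, unravel = ravel_pytree(params)
vec = jnp.ones_like(flat_params)


def structures_match(a, b):
    """True if two pytrees have the same container structure."""
    return jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)


def try_update(p, v):
    """Add v to p leaf by leaf; return the exception if the structures clash."""
    try:
        return jax.tree_util.tree_map(lambda x, y: x + y, p, v)
    except (ValueError, TypeError) as err:
        return err


print(structures_match(params, vec))           # False: dict vs. single array
print(structures_match(params, unravel(vec)))  # True once vec is unflattened
```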
Hi @bolducke thanks for reporting this. Do you have an example that I can use for the repro? I plan to update the unit tests to cover that use case for both GN and LM.
@amir-saadat I was going to report this as well. I wrote a very simple test case to demonstrate it, though I'm not sure it's what you're looking for.
```python
import jax
import jax.numpy as jnp
import jaxopt

jax.config.update("jax_enable_x64", True)

M = 5
# The same parameters stored two ways: a (2, M) array and a dict of two vectors.
params = jnp.zeros((2, M))
params = params.at[0].set(jnp.arange(M) * 1.0)
params = params.at[1].set(jnp.arange(M) ** 2 * 1.0)
params_dict = {'A': jnp.arange(M) * 1.0, 'B': jnp.arange(M) ** 2 * 1.0}


def F(params):
    # Two residuals computed from the rows of a 2D array.
    return jnp.asarray([jnp.sum(params[0]), jnp.sum(params[0] * params[1] ** 2)])


def F_dict(params):
    # The same residuals computed from a dict pytree.
    return jnp.asarray([jnp.sum(params['A']), jnp.sum(params['A'] * params['B'] ** 2)])


def optimize_F_gn(params, F):
    gn = jaxopt.GaussNewton(residual_fun=F)
    return gn.run(params).params


def optimize_F_lm(params, F):
    lm = jaxopt.LevenbergMarquardt(residual_fun=F)
    return lm.run(params).params


print(params)
print(optimize_F_gn(params, F))
print(optimize_F_gn(params_dict, F_dict))
print(optimize_F_lm(params, F))  # fails
print(optimize_F_lm(params_dict, F_dict))  # fails
```
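Until LM handles pytrees directly, the flattening workaround mentioned at the top of the issue can be packaged as a small wrapper. This is a sketch, not part of jaxopt: `flatten_residual` and `run_lm_flattened` are hypothetical names, and it assumes `jax.flatten_util.ravel_pytree` is suitable for the parameter pytree (it concatenates all leaves into one 1-D array, promoting them to a common dtype).

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def flatten_residual(residual_fun, params):
    """Return a residual over one flat vector, the flat initial point,
    and the function that maps a flat vector back to the original pytree."""
    flat_params, unravel = ravel_pytree(params)

    def flat_residual(flat):
        return residual_fun(unravel(flat))

    return flat_residual, flat_params, unravel


def run_lm_flattened(residual_fun, params, **solver_kwargs):
    """Run jaxopt's LevenbergMarquardt on a pytree by solving in flat space."""
    import jaxopt  # only needed for the solve itself

    flat_residual, flat_params, unravel = flatten_residual(residual_fun, params)
    solver = jaxopt.LevenbergMarquardt(residual_fun=flat_residual, **solver_kwargs)
    # Map the flat solution back to the caller's pytree layout.
    return unravel(solver.run(flat_params).params)
```

For example, `run_lm_flattened(F_dict, params_dict)` should return a dict with keys `'A'` and `'B'`, the same layout as the input; extra keyword arguments such as `maxiter` are forwarded to `jaxopt.LevenbergMarquardt`.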