
LevenbergMarquardt does not seem to work with non-flat input.

Status: Open. bolducke opened this issue 1 year ago · 2 comments

GaussNewton works as intended with pytree parameters, and I would expect the same from LevenbergMarquardt (LM). Instead, I had to flatten the parameters into a 1-D array to make it work properly.
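The flattening workaround can be sketched with `jax.flatten_util.ravel_pytree`, which turns a pytree of arrays into one 1-D vector plus a function that rebuilds the original structure. The helper name `make_flat_residual` and the example residual are hypothetical, for illustration only; this is a sketch of the user-side workaround, not jaxopt's internal handling of pytrees.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def make_flat_residual(residual_fun, params_tree):
    """Hypothetical helper: wrap a pytree-input residual so it accepts a flat vector.

    Returns the wrapped residual, the flattened initial parameters, and the
    function that maps a flat vector back to the original pytree structure.
    """
    flat_params, unravel = ravel_pytree(params_tree)

    def flat_residual(flat_x):
        # Rebuild the pytree from the 1-D vector before calling the original residual.
        return residual_fun(unravel(flat_x))

    return flat_residual, flat_params, unravel


# Example pytree input: a dict of two length-3 arrays.
params_tree = {"A": jnp.arange(3.0), "B": jnp.arange(3.0) ** 2}


def residual(p):
    return jnp.asarray([jnp.sum(p["A"]), jnp.sum(p["A"] * p["B"] ** 2)])


flat_residual, flat0, unravel = make_flat_residual(residual, params_tree)
print(flat0.shape)           # (6,)
print(flat_residual(flat0))  # same values as residual(params_tree)
```

With such a wrapper the solver only ever sees a 1-D array, e.g. `jaxopt.LevenbergMarquardt(residual_fun=flat_residual).run(flat0)`, and the returned `params` can be mapped back to the dict layout with `unravel(...)`.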


The error is raised at line 445. It comes from the fact that the pytree structures of `params` and `vec` do not match.
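One way to keep the two structures aligned is to do the linear algebra on a flattened copy and map the resulting step back to the pytree before combining it leaf-wise with `params`. The sketch below shows a single textbook Levenberg-Marquardt step, solving (J^T J + mu I) delta = -J^T r. The function `lm_step_pytree`, the fixed damping `mu`, and the example residual are hypothetical; this is not jaxopt's implementation, only an illustration of where the unravel has to happen.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def lm_step_pytree(residual_fun, params, mu=1e-2):
    """Hypothetical single Levenberg-Marquardt step for pytree params.

    Solves (J^T J + mu * I) delta = -J^T r on a flat vector, then unravels
    delta so that it has exactly the same pytree structure as `params`.
    """
    flat_params, unravel = ravel_pytree(params)

    def flat_residual(x):
        return residual_fun(unravel(x))

    r = flat_residual(flat_params)                   # residuals, shape (m,)
    J = jax.jacfwd(flat_residual)(flat_params)       # Jacobian, shape (m, n)
    lhs = J.T @ J + mu * jnp.eye(flat_params.size)   # damped normal matrix, (n, n)
    delta_flat = jnp.linalg.solve(lhs, -J.T @ r)     # flat step, shape (n,)
    delta = unravel(delta_flat)                      # same structure as params
    # Both trees now share one structure, so a leaf-wise add is well defined.
    return jax.tree_util.tree_map(lambda p, d: p + d, params, delta)


def residual(p):
    # Hypothetical residual: pull p["a"] toward 1.0 and p["b"] toward [2.0, 3.0].
    return jnp.concatenate([p["a"] - 1.0, p["b"] - jnp.array([2.0, 3.0])])


params = {"a": jnp.zeros(1), "b": jnp.zeros(2)}
new_params = lm_step_pytree(residual, params)
```

Here `new_params` is again a dict with keys `"a"` and `"b"`; the flat vector never meets the pytree directly.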

bolducke · Aug 22 '23 15:08

Hi @bolducke, thanks for reporting this. Do you have an example I can use as a repro? I plan to update the unit tests to cover this use case for both GN and LM.

amir-saadat · Sep 08 '23 03:09

@amir-saadat I was going to report this as well. I wrote a very simple test case to demonstrate it, though I'm not sure it's what you're looking for.

import jax
import jax.numpy as jnp
import jaxopt

# Enable float64 before any arrays are created.
jax.config.update("jax_enable_x64", True)

M = 5
params = jnp.zeros((2,M))
params = params.at[0].set(jnp.arange(M) * 1.0)
params = params.at[1].set(jnp.arange(M)**2 * 1.0)
params_dict = {'A': jnp.arange(M) * 1.0, 'B': jnp.arange(M)**2 * 1.0}

def F(params):
	return jnp.asarray([jnp.sum(params[0]), jnp.sum(params[0] * params[1]**2)])

def F_dict(params):
	return jnp.asarray([jnp.sum(params['A']), jnp.sum(params['A'] * params['B']**2)])

def optimize_F_gn(params, F):
	gn = jaxopt.GaussNewton(residual_fun=F)
	return gn.run(params).params

def optimize_F_lm(params, F):
	lm = jaxopt.LevenbergMarquardt(residual_fun=F)
	return lm.run(params).params

print(params)
print(optimize_F_gn(params, F))  # works
print(optimize_F_gn(params_dict, F_dict))  # works
print(optimize_F_lm(params, F))  # fails
print(optimize_F_lm(params_dict, F_dict))  # fails

nickmcgreivy · Jan 22 '24 01:01