
Different results between jaxopt.LBFGS and jaxopt.ScipyMinimize(method='l-bfgs-b')

richinex opened this issue 3 years ago · 7 comments

I am trying to run a complex nonlinear optimization on multi-dimensional data by applying vmap to the solver's run method. Since I could not use the l-bfgs-b method in the ScipyMinimize wrapper, I resorted to jaxopt.LBFGS. However, I realized that the result from the latter is not correct. I would like to know why, and what I could do about it. My minimal working example is shown below. Thanks.

# Load libraries (unused imports from the original post removed)
import jax
import jaxopt
from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)

# Make data
F_arr = jnp.array([4.000000e+00, 1.250000e+01, 2.100000e+01, 2.950000e+01,
                   3.800000e+01, 4.650000e+01, 6.350000e+01, 8.050000e+01,
                   1.230000e+02, 1.570000e+02, 1.995000e+02, 2.505000e+02,
                   3.185000e+02, 3.940000e+02, 4.965000e+02, 6.260000e+02,
                   7.925000e+02, 9.975000e+02, 1.256000e+03, 1.581000e+03,
                   1.990500e+03, 2.506000e+03, 3.155000e+03, 3.971500e+03,
                   5.000000e+03, 6.294500e+03, 7.924500e+03, 9.976500e+03,
                   1.255950e+04, 1.581150e+04, 1.990550e+04, 2.505950e+04,
                   3.154800e+04, 3.971650e+04, 5.000000e+04])


Y_arr = jnp.array([0.00495074+0.00290374j, 0.00724701+0.00289439j,
                   0.00821288+0.00279885j, 0.00877054+0.00276919j,
                   0.00921332+0.0027551j,  0.00953043+0.00274739j,
                   0.01002155+0.00274946j, 0.01038829+0.00279736j,
                   0.01103745+0.00293741j, 0.01143682+0.00304808j,
                   0.01185019+0.00321095j, 0.01222892+0.00340771j,
                   0.01264666+0.00365856j, 0.01312294+0.00390083j,
                   0.01356835+0.00423682j, 0.01414305+0.00459166j,
                   0.01475416+0.00502188j, 0.01544523+0.0054795j,
                   0.01620464+0.00597393j, 0.01707565+0.00650766j,
                   0.01800564+0.00707323j, 0.01907494+0.00766403j,
                   0.0202539+0.00824607j,  0.02156295+0.00882627j,
                   0.02293967+0.0093636j,  0.02446602+0.00988404j,
                   0.02606663+0.01034258j, 0.02778773+0.01073912j,
                   0.0295645+0.01105176j,  0.03142458+0.01130524j,
                   0.03332406+0.01142638j, 0.03529196+0.01141756j,
                   0.03725344+0.01128458j, 0.03917468+0.01100424j,
                   0.04104471+0.0105539j], dtype='complex64')


sigma_arr = jnp.array([2.43219802e-06, 3.84912892e-06, 4.65468565e-06, 5.23176095e-06,
                       5.68508176e-06, 6.05872401e-06, 6.64642994e-06, 7.11385064e-06,
                       7.95151846e-06, 8.43719499e-06, 8.92535354e-06, 9.37367440e-06,
                       9.85436691e-06, 1.02790955e-05, 1.07571723e-05, 1.12416874e-05,
                       1.17638756e-05, 1.22902720e-05, 1.28422944e-05, 1.34043157e-05,
                       1.39690355e-05, 1.45196518e-05, 1.50516798e-05, 1.55538437e-05,
                       1.60391119e-05, 1.65177389e-05, 1.69736650e-05, 1.74361339e-05,
                       1.78881437e-05, 1.83461307e-05, 1.87868263e-05, 1.92436037e-05,
                       1.96903675e-05, 2.01275998e-05, 2.05546330e-05])

params_init = jnp.array([1.84285135e+01, 1.71039097e-05, 6.98550706e-01,
             7.33632243e-01, 4.77681912e+02, 5.65632259e-04,
             1.34721147e+01, 1.34025052e+02, 3.93700063e+00,
             2.96283162e-01, 2.31503009e-01])

parameter_bounds = [[1e-1, 1e6], [1e-7, 1e-1], [1e-1, 1], [1e-1, 1e7],
                    [1e-1, 1e7], [1e-7, 1e-1], [1e-1, 1e6], [1e-1, 1e7],
                    [1e-1, 1e7], [1e-1, 1], [1e-1, 1e7]]
lb = jnp.array([i[0] for i in parameter_bounds])
ub = jnp.array([i[1] for i in parameter_bounds])

n_par = len(params_init)
n_data = 5
n_freq = len(F_arr)

# form matrices from the initial parameters and the bounds
par_mat = jnp.broadcast_to(params_init[:, None], (n_par, n_data))
lb_mat = jnp.broadcast_to(lb[:, None], (n_par, n_data))
ub_mat = jnp.broadcast_to(ub[:, None], (n_par, n_data))

# convert external (bounded) parameters to internal (unconstrained) ones
par_log = jnp.log10((par_mat - lb_mat) / (1 - par_mat / ub_mat))
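
# Round-trip check (editor's sketch, not in the original post): inverting the
# transform above with the mapping used later in cost_fun recovers par_mat.
par_roundtrip = (lb_mat + 10**par_log) / (1 + 10**par_log / ub_mat)
assert jnp.allclose(par_roundtrip, par_mat)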

# create matrices from F_arr, Y_arr and sigma_arr
F = jnp.broadcast_to(F_arr[:, None], (n_freq, n_data))
Y = jnp.broadcast_to(Y_arr[:, None], (n_freq, n_data))
sigma_Y = jnp.broadcast_to(sigma_arr[:, None], (n_freq, n_data))

# Define model: complex admittance, returned with real and imaginary parts stacked
@jax.jit
def fun(p, f):
    w = 2*jnp.pi*f
    Rs = p[0]
    Qh = p[1]*p[10]
    nh = p[2]
    Rad = p[3]/p[10]
    Wad = p[4]/p[10]
    Cad = p[5]*p[10]
    Rint = p[6]/p[10]
    Wint = p[7]/p[10]
    tau = p[8]
    alpha = p[9]
    Rp = p[10]
    Ct = Cad  # note: (1/Cad)**-1 simplifies to Cad
    Zad = Rad + Wad/jnp.sqrt(1j*w)
    Zint = Rint + Wint/((1j*w*tau)**(alpha/2)) * 1/(jnp.tanh((1j*w*tau)**(alpha/2)))
    Yf = (Zad + (1j*w*Ct)**-1)/(Zad*Zint + (Zad+Zint)*(1j*w*Ct)**-1)
    Ydl = Qh*((1j*w)**nh)
    Kl = jnp.sqrt(Ydl + Yf)
    Z = Rs + Rp * jnp.tanh(Kl)**-1 / Kl
    Y = 1/Z
    return jnp.concatenate([Y.real, Y.imag], axis=0)
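
# Shape check (editor's sketch, not in the original post): the model stacks
# real and imaginary parts, so its output has length 2 * n_freq.
assert fun(params_init, F_arr).shape == (2 * n_freq,)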

# weighted sum-of-squares residual for one data column
# (the division by the degrees of freedom happens in cost_fun below)
@jax.jit
def obj_fun(p, x, y, yerr):
    y_concat = jnp.concatenate([y.real, y.imag], axis=0)
    sigma = jnp.concatenate([yerr, yerr], axis=0)
    y_model = fun(p, x)
    chi_sqr = jnp.linalg.norm((y_concat - y_model) / sigma)**2
    return chi_sqr

# Multidimensional cost function: vmap obj_fun over the data columns
@jax.jit
def cost_fun(P, X, Y, YERR, LB, UB):
    dof = 2*len(X[0])*len(X) - len(P)
    P_norm = (LB + 10**P) / (1 + 10**P / UB)  # internal -> external parameters
    chi = jax.vmap(obj_fun, in_axes=1)(P_norm, X, Y, YERR)
    return jnp.sum(chi) / dof
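
# Sanity check (editor's sketch, not in the original post): cost_fun reduces
# the whole batch to one scalar, so jax.grad yields a gradient with the same
# (n_par, n_data) shape as the internal parameter matrix.
init_cost = cost_fun(par_log, F, Y, sigma_Y, lb_mat, ub_mat)            # scalar
init_grad = jax.grad(cost_fun)(par_log, F, Y, sigma_Y, lb_mat, ub_mat)  # (11, 5)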

# Run the optimization
solver_1 = jaxopt.ScipyMinimize(method="l-bfgs-b", fun=cost_fun, tol=1e-12, options={"maxiter": 5000})
solver_1_sol = solver_1.run(par_log, F, Y, sigma_Y, lb_mat, ub_mat)
solver_1_sol.params[:, 0]

# Correct result
# DeviceArray([ 1.16801937e+00, -4.40860838e+00,  1.70230450e-01,
#              -1.91683037e+00,  2.81482271e+00, -3.33403710e+00,
#               1.58763377e+00,  1.96265752e+00, -9.80790594e+02,
#              -4.01386155e-01,  1.04828148e+00], dtype=float64)
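
# Recovering physical parameters (editor's note, not in the original post):
# the solver works in the internal log-space, so the same inverse transform
# used inside cost_fun maps the solution back to external parameters.
p_ext = (lb_mat + 10**solver_1_sol.params) / (1 + 10**solver_1_sol.params / ub_mat)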


solver_2 = jaxopt.LBFGS(fun=cost_fun, maxiter=5000)
solver_2_sol = solver_2.run(par_log, X=F, Y=Y, YERR=sigma_Y, LB=lb_mat, UB=ub_mat)
solver_2_sol.params[:, 0]

# Incorrect result
# DeviceArray([ 5.13495057e+04, -1.12978597e+04, -1.15845934e+04,
#               6.74886483e+01,  2.41560694e+03, -1.42666112e+04,
#               3.63855005e+04,  1.28521409e+05, -1.94682351e+04,
#              -7.21753278e+04,  1.21914493e+02], dtype=float64)

richinex · Jul 14 '22

Could you also try jaxopt.LBFGS(..., linesearch="zoom")?

jaxopt.LBFGS and SciPy's L-BFGS-B don't use the same line-search technique, so the two can sometimes reach different results, in particular when the function being minimized is nonconvex.
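
For concreteness, a minimal sketch of that suggestion, reusing cost_fun and the data from the issue:

solver = jaxopt.LBFGS(fun=cost_fun, maxiter=5000, linesearch="zoom")
sol = solver.run(par_log, X=F, Y=Y, YERR=sigma_Y, LB=lb_mat, UB=ub_mat)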

mblondel · Jul 17 '22

I also tried the zoom line search and did not get the correct results. You're right, the problem is nonconvex.

richinex · Jul 17 '22

Nevertheless, I found that I could use list(map(func, *args)) instead of vmap with jaxopt's ScipyMinimize, which solves my problem for now.
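
A sketch of what that workaround might look like (solve_column and col_obj are illustrative names, not from the thread): each column is solved independently with ScipyMinimize, and a plain map over the transposed batch replaces vmap.

def solve_column(p0, f, y, yerr, lb, ub):
    def col_obj(p):
        p_ext = (lb + 10**p) / (1 + 10**p / ub)  # internal -> external
        return obj_fun(p_ext, f, y, yerr)
    solver = jaxopt.ScipyMinimize(method="l-bfgs-b", fun=col_obj, tol=1e-12)
    return solver.run(p0).params

results = list(map(solve_column, par_log.T, F.T, Y.T, sigma_Y.T, lb_mat.T, ub_mat.T))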

richinex · Jul 17 '22

@richinex this could be related; with the fixes in https://github.com/google/jaxopt/pull/323 and https://github.com/google/jaxopt/pull/350, the results might now be consistent with those of core JAX. Maybe you can give it a shot.

zaccharieramzi · Dec 08 '22