jaxopt
Different results between jaxopt.LBFGS and jaxopt.ScipyMinimize(method='l-bfgs-b')
I am trying to run a complex nonlinear optimization on multi-dimensional data by applying vmap to the solver's run method. Since I could not use the l-bfgs-b method in the ScipyMinimize wrapper, I resorted to jaxopt.LBFGS. However, I realized that the result from the latter was not correct. I would like to know why, and what I could do about it. My minimal working example is shown below. Thanks
# Load libraries
import numpy as np
import scipy.sparse as sps
import jax
import jaxopt
from jax import value_and_grad
from jax import numpy as jnp
from jax.example_libraries import optimizers as jax_opt
jax.config.update("jax_enable_x64", True)
# Make data
F_arr = jnp.array([4.000000e+00, 1.250000e+01, 2.100000e+01, 2.950000e+01, 3.800000e+01, 4.650000e+01, 6.350000e+01, 8.050000e+01,
1.230000e+02, 1.570000e+02, 1.995000e+02, 2.505000e+02,3.185000e+02, 3.940000e+02, 4.965000e+02, 6.260000e+02,7.925000e+02, 9.975000e+02, 1.256000e+03, 1.581000e+03,
1.990500e+03, 2.506000e+03, 3.155000e+03, 3.971500e+03,5.000000e+03, 6.294500e+03, 7.924500e+03, 9.976500e+03,1.255950e+04, 1.581150e+04, 1.990550e+04, 2.505950e+04,
3.154800e+04, 3.971650e+04, 5.000000e+04])
Y_arr = jnp.array([0.00495074+0.00290374j, 0.00724701+0.00289439j,0.00821288+0.00279885j, 0.00877054+0.00276919j,
0.00921332+0.0027551j , 0.00953043+0.00274739j,0.01002155+0.00274946j, 0.01038829+0.00279736j,
0.01103745+0.00293741j, 0.01143682+0.00304808j,0.01185019+0.00321095j, 0.01222892+0.00340771j,
0.01264666+0.00365856j, 0.01312294+0.00390083j,0.01356835+0.00423682j, 0.01414305+0.00459166j,
0.01475416+0.00502188j, 0.01544523+0.0054795j ,0.01620464+0.00597393j, 0.01707565+0.00650766j,
0.01800564+0.00707323j, 0.01907494+0.00766403j,0.0202539 +0.00824607j, 0.02156295+0.00882627j,
0.02293967+0.0093636j , 0.02446602+0.00988404j,0.02606663+0.01034258j, 0.02778773+0.01073912j,
0.0295645 +0.01105176j, 0.03142458+0.01130524j,0.03332406+0.01142638j, 0.03529196+0.01141756j,
0.03725344+0.01128458j, 0.03917468+0.01100424j,0.04104471+0.0105539j], dtype='complex64')
sigma_arr = jnp.array([2.43219802e-06, 3.84912892e-06, 4.65468565e-06, 5.23176095e-06,5.68508176e-06, 6.05872401e-06, 6.64642994e-06, 7.11385064e-06,
7.95151846e-06, 8.43719499e-06, 8.92535354e-06, 9.37367440e-06,9.85436691e-06, 1.02790955e-05, 1.07571723e-05, 1.12416874e-05,
1.17638756e-05, 1.22902720e-05, 1.28422944e-05, 1.34043157e-05,1.39690355e-05, 1.45196518e-05, 1.50516798e-05, 1.55538437e-05,
1.60391119e-05, 1.65177389e-05, 1.69736650e-05, 1.74361339e-05,1.78881437e-05, 1.83461307e-05, 1.87868263e-05, 1.92436037e-05,
1.96903675e-05, 2.01275998e-05, 2.05546330e-05])
params_init = jnp.array([1.84285135e+01, 1.71039097e-05, 6.98550706e-01,
7.33632243e-01, 4.77681912e+02, 5.65632259e-04,
1.34721147e+01, 1.34025052e+02, 3.93700063e+00,
2.96283162e-01, 2.31503009e-01])
parameter_bounds = [[1e-1, 1e6], [1e-7, 1e-1], [1e-1, 1], [1e-1, 1e7], [1e-1, 1e7], [1e-7, 1e-1], [1e-1, 1e6], [1e-1, 1e7], [1e-1, 1e7], [1e-1, 1], [1e-1, 1e7]]
lb = jnp.array([i[0] for i in parameter_bounds])
ub = jnp.array([i[1] for i in parameter_bounds])
n_par = len(params_init)
n_data = 5
n_freq = len(F_arr)
# form a matrix from the initial parameters and the bounds
par_mat = jnp.broadcast_to(params_init[:,None], (len(params_init), n_data))
lb_mat = jnp.broadcast_to(lb[:,None], (len(params_init), n_data))
ub_mat = jnp.broadcast_to(ub[:,None], (len(params_init), n_data))
# convert external to internal parameters
par_log = jnp.log10((par_mat - lb_mat) / (1-par_mat / ub_mat))
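# note: this is the inverse of the back-transform in cost_fun below,
# i.e. (lb_mat + 10**par_log) / (1 + 10**par_log / ub_mat) recovers par_mat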
# create a matrix from F_arr, Y_arr and sigma_arr
F = jnp.broadcast_to(F_arr[:,None], (n_freq, n_data))
Y = jnp.broadcast_to(Y_arr[:,None], (n_freq, n_data))
sigma_Y = jnp.broadcast_to(sigma_arr[:,None], (n_freq, n_data))
# Define model
@jax.jit
def fun(p, f):
    # impedance model; returns the real and imaginary parts of the model admittance stacked along axis 0
    w = 2*jnp.pi*f
    Rs = p[0]
    Qh = p[1]*p[10]
    nh = p[2]
    Rad = p[3]/p[10]
    Wad = p[4]/p[10]
    Cad = p[5]*p[10]
    Rint = p[6]/p[10]
    Wint = p[7]/p[10]
    tau = p[8]
    alpha = p[9]
    Rp = p[10]
    Ct = (1/Cad)**-1
    Zad = Rad + Wad/jnp.sqrt(1j*w)
    Zint = Rint + Wint/((1j*w*tau)**(alpha/2)) * 1/(jnp.tanh((1j*w*tau)**(alpha/2)))
    Yf = (Zad + (1j*w*Ct)**-1)/(Zad*Zint + (Zad+Zint)*(1j*w*Ct)**-1)
    Ydl = Qh*((1j*w)**nh)
    Kl = jnp.sqrt(Ydl + Yf)
    Z = Rs + Rp * jnp.tanh(Kl)**-1 / Kl
    Y = 1/Z
    return jnp.concatenate([Y.real, Y.imag], axis=0)
# sum of squares residual
@jax.jit
def obj_fun(p, x, y, yerr):
    ndata = len(x)
    dof = (2*ndata - len(p))
    y_concat = jnp.concatenate([y.real, y.imag], axis=0)
    sigma = jnp.concatenate([yerr, yerr], axis=0)
    y_model = fun(p, x)
    # chi_sqr = ((jnp.abs((1/sigma) * (y_concat - y_model))))
    chi_sqr = jnp.linalg.norm((y_concat - y_model)/sigma)**2
    return chi_sqr
# Multidimensional cost function
@jax.jit
def cost_fun(P, X, Y, YERR, LB, UB):
    dof = (2*len(X[0])*len(X)) - len(P)
    # map the internal (log-scale) parameters back to bounded external parameters
    P_norm = (LB + 10**P) / (1 + 10**P / UB)
    chi = jax.vmap(obj_fun, in_axes=1)(P_norm, X, Y, YERR)
    return jnp.sum(chi) / dof
# Run the optimization
solver_1 = jaxopt.ScipyMinimize(method = "l-bfgs-b", fun=cost_fun, tol = 1e-12, options ={'maxiter':5000})
solver_1_sol = solver_1.run(par_log, F, Y, sigma_Y, lb_mat, ub_mat)
solver_1_sol.params[:, 0]
# Correct result
# DeviceArray([ 1.16801937e+00, -4.40860838e+00, 1.70230450e-01,
# -1.91683037e+00, 2.81482271e+00, -3.33403710e+00,
# 1.58763377e+00, 1.96265752e+00, -9.80790594e+02,
# -4.01386155e-01, 1.04828148e+00], dtype=float64)
solver_2 = jaxopt.LBFGS(fun=cost_fun, maxiter = 5000)
solver_2_sol = solver_2.run(par_log, X=F, Y=Y, YERR=sigma_Y, LB=lb_mat, UB=ub_mat)
solver_2_sol.params[:, 0]
# Incorrect result
# DeviceArray([ 5.13495057e+04, -1.12978597e+04, -1.15845934e+04,
# 6.74886483e+01, 2.41560694e+03, -1.42666112e+04,
# 3.63855005e+04, 1.28521409e+05, -1.94682351e+04,
# -7.21753278e+04, 1.21914493e+02], dtype=float64)
Could you also try with jaxopt.LBFGS(..., linesearch="zoom")?
jaxopt.LBFGS and SciPy's L-BFGS-B don't use the same line search technique, so it's possible that they sometimes return different results, especially if the function being minimized is nonconvex.
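For reference, using the zoom line search mentioned above would look like this (a sketch only, reusing cost_fun and the data from the example):
solver_3 = jaxopt.LBFGS(fun=cost_fun, maxiter=5000, linesearch="zoom")
solver_3_sol = solver_3.run(par_log, X=F, Y=Y, YERR=sigma_Y, LB=lb_mat, UB=ub_mat)
solver_3_sol.params[:, 0]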
I did also try the zoom line search and did not get the correct results. You're right, the problem is nonconvex.
Nevertheless, I found that I could use list(map(func, *args)) instead of vmap with jaxopt.ScipyMinimize, and that solves my problem for now.
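For concreteness, a rough sketch of that workaround (assuming a hypothetical per-column cost function, here called cost_fun_1d, that applies the same parameter transform as cost_fun and is solved one column at a time):
def cost_fun_1d(p, x, y, yerr, lb, ub):
    # per-column version of cost_fun: same external/internal transform, single data column
    dof = 2*len(x) - len(p)
    p_norm = (lb + 10**p) / (1 + 10**p / ub)
    return obj_fun(p_norm, x, y, yerr) / dof

solver = jaxopt.ScipyMinimize(method="l-bfgs-b", fun=cost_fun_1d, tol=1e-12, options={'maxiter': 5000})
results = list(map(lambda p0, x, y, yerr, lb, ub: solver.run(p0, x, y, yerr, lb, ub).params,
                   par_log.T, F.T, Y.T, sigma_Y.T, lb_mat.T, ub_mat.T))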
@richinex this could be related, and with the fixes in https://github.com/google/jaxopt/pull/323 and https://github.com/google/jaxopt/pull/350 the results might become consistent with those of core JAX. Maybe you can give it a shot.
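To compare against core JAX directly, a minimal sketch (assuming jax.scipy.optimize.minimize with method="BFGS", which expects a flat 1-D parameter vector) could look like:
from jax.scipy.optimize import minimize as jax_minimize

def flat_cost(p_flat, X, Y, YERR, LB, UB):
    # jax.scipy.optimize.minimize optimizes a 1-D vector, so reshape back to (n_par, n_data)
    return cost_fun(p_flat.reshape(n_par, n_data), X, Y, YERR, LB, UB)

ref_sol = jax_minimize(flat_cost, par_log.ravel(),
                       args=(F, Y, sigma_Y, lb_mat, ub_mat), method="BFGS")
ref_sol.x.reshape(n_par, n_data)[:, 0]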