ScipyBoundedMinimize : recompilation of function to minimize at each instantiation?

Open · jecampagne opened this issue 3 years ago · 2 comments

Discussed in https://github.com/google/jaxopt/discussions/220

Originally posted by jecampagne, May 9, 2022

Hello,

Here is a simple snippet:

```python
import jax
import jax.numpy as jnp
from jax import jit
import jaxopt

jax.config.update("jax_enable_x64", True)

@jit
def test(x):
    # Python side effect: runs only while JAX traces the function,
    # so each "compile...." line marks a (re)compilation.
    print("compile....")
    return -(x + 2.) ** 2

tol = 1e-3
method = 'L-BFGS-B'  # or 'SLSQP'
options = {'disp': False, 'ftol': tol, 'gtol': tol, 'maxiter': 600}

# First solver instance: its first run triggers compilation.
jscMin = jaxopt.ScipyBoundedMinimize(fun=test, method=method, tol=tol, options=options)
res = jscMin.run(jnp.array(10.), bounds=(-100, 100))
print(">>> res1", res)

# Second, identical instance: compiles again on its first run.
jscMin = jaxopt.ScipyBoundedMinimize(fun=test, method=method, tol=tol, options=options)
res = jscMin.run(jnp.array(5.), bounds=(-100, 100))
print(">>> res2", res)

# Same instance reused with a new starting point: no new compilation.
res = jscMin.run(jnp.array(-5.), bounds=(-100, 100))
print(">>> res3", res)
```

which prints:

```
compile....
>>> res1 OptStep(params=DeviceArray(100., dtype=float64), state=ScipyMinimizeInfo(fun_val=DeviceArray(-10404., dtype=float64, weak_type=True), success=True, status=0, iter_num=2))
compile....
>>> res2 OptStep(params=DeviceArray(100., dtype=float64), state=ScipyMinimizeInfo(fun_val=DeviceArray(-10404., dtype=float64, weak_type=True), success=True, status=0, iter_num=2))
>>> res3 OptStep(params=DeviceArray(-100., dtype=float64), state=ScipyMinimizeInfo(fun_val=DeviceArray(-9604., dtype=float64, weak_type=True), success=True, status=0, iter_num=2))
```

Is it expected that `test` is compiled again for each instantiation of ScipyBoundedMinimize? In other words, to avoid recompilation, should the user reuse a single jscMin?

The use case: one sometimes needs to try several initial starting points and then pick the best parameters, so instantiating ScipyBoundedMinimize several times could add overhead simply from recompiling the user function to minimize (here `test`, but it could be dramatically more complicated).

Any hints on how to optimize this better? Thanks.

jecampagne, May 10 '22 07:05

Hi,

Jaxopt follows the spirit of JAX: objects are stateless bags of hyper-parameters, and different hyper-parameter values can wrap different implementations of the algorithm. For example, the `unroll` or `jit` hyper-parameter selects a different implementation of the for loop, with different memory requirements. JAX needs to know the computation graph and the memory consumption of each instruction in advance. Hence, it is expected that each newly created object requires a jit compilation, since the code to be run can change.
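The reply does not show jaxopt's internals, so the following is only a plain-JAX sketch of the mechanism it describes, assuming (as the repeated `compile....` lines in the question suggest) that each solver object wraps `fun` in its own jitted function. JAX caches traced and compiled code keyed on the function being jitted, and `jax.value_and_grad` returns a new function object on every call. `make_value_and_grad` and `trace_count` are hypothetical names for illustration, not jaxopt APIs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces `objective`

def objective(x):
    global trace_count
    trace_count += 1  # Python side effect: executes at trace time, not on cached calls
    return -(x + 2.0) ** 2

def make_value_and_grad():
    # Hypothetical stand-in for what a solver object might build when created:
    # a brand-new jitted value-and-grad function around the user's objective.
    return jax.jit(jax.value_and_grad(objective))

# Case 1: build the wrapper once and reuse it for several starting points.
vg = make_value_and_grad()
for x0 in (10.0, 5.0, -5.0):
    vg(jnp.array(x0))
reused_traces = trace_count

# Case 2: build a fresh wrapper for every starting point.
trace_count = 0
for x0 in (10.0, 5.0, -5.0):
    make_value_and_grad()(jnp.array(x0))
fresh_traces = trace_count

print(reused_traces, fresh_traces)  # expected: 1 3
```

Reusing one wrapper traces the objective once for all three inputs (same shape and dtype), while a fresh wrapper per starting point retraces every time, which mirrors reusing one solver object versus creating a new one per run.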

ScipyBoundedMinimize is better off in that regard: compilation only occurs on the first call to `run`, not on every call (as is usually the case for other Jaxopt solvers). If you need to solve several problems with the solver, I suggest keeping the same Scipy wrapper object. As illustrated below, you can even change the hyper-parameters of the method on the fly for Scipy wrapper objects:

```python
import jax
import jax.numpy as jnp
from jax import jit
import jaxopt

jax.config.update("jax_enable_x64", True)

@jit
def test(x):
    print("compile....")  # runs only while JAX traces `test`
    return -(x + 2.) ** 2

tol = 1e-3
method = 'L-BFGS-B'  # or 'SLSQP'
options = {'disp': False, 'ftol': tol, 'gtol': tol, 'maxiter': 600}
jscMin = jaxopt.ScipyBoundedMinimize(fun=test, method=method, tol=tol, options=options)

# Repeated runs on the same object compile only once.
res = jscMin.run(jnp.array(10.), bounds=(-100, 100))
print(">>> res1", res)
res = jscMin.run(jnp.array(10.), bounds=(-100, 100))
print(">>> res2", res)

# Hyper-parameters of a Scipy wrapper object can be changed on the fly.
print('Change HYPER-PARAM')
jscMin.method = 'SLSQP'
res = jscMin.run(jnp.array(10.), bounds=(-100, 100))
print(">>> res3", res)

# A new object (here with empty options) is compiled again on its first run.
jscMin = jaxopt.ScipyBoundedMinimize(fun=test, method=method, tol=tol, options={})
print("")
res = jscMin.run(jnp.array(5.), bounds=(-100, 100))
print(">>> res4", res)
res = jscMin.run(jnp.array(-5.), bounds=(-100, 100))
print(">>> res5", res)
```

Algue-Rythme, May 24 '22 13:05

Can this issue be closed?

mblondel, Jun 10 '22 16:06