Benchmark of lbfgs 10 times slower that scipy on cpu
I wanted to migrate to optax for a problem I was working on and I noticed that lbgfs is really slower that scipy. I thus implemented a small benchmark on the Rosenbrock function, and indeed it is about 10x slower. I implemented the loop myself (method == Optax) or using the example from the doc : lbgfs method (Optax-help)
The naive implementation that I did is about 10x slower in optax. using the one from the doc (Optax-help) reduced it to 4-5x slower. My implementation has the same number of iteration than the one from scipy. Scipy method is L-BFGS-B however if I understood correctly in should be the same method that L-BFGS, except from the fact that you can include bounds. Here are the results (I ran them twice in a raw ):
JAX version 0.6.0
Optax version 0.2.5.dev
Method | Time (s) | NIT | Loss
------------------------------------------------
L-BFGS-B | 0.073 | 494 | 5.05e-10
Optax | 0.616 | 495 | 3.88e-07
Optax-help | 0.283 | 516 | 6.62e-07
Method | Time (s) | NIT | Loss
------------------------------------------------
L-BFGS-B | 0.081 | 494 | 5.05e-10
Optax | 0.561 | 495 | 3.88e-07
Optax-help | 0.283 | 516 | 6.62e-07
import numpy as np
import jax
import jax.numpy as jnp
import optax
from scipy.optimize import minimize
import time
import optax
import optax.tree_utils as otu
# Configure JAX to use 64-bit precision
jax.config.update("jax_enable_x64", True)
# Rosenbrock function (standard test function)
def rosenbrock(x):
return np.sum(100.0*(x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2)
# JAX-compatible version
@jax.jit
def rosenbrock_jax(x):
return jnp.sum(100.0*(x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2)
# Gradient for SciPy
def rosen_grad(x):
grad = np.zeros_like(x)
grad[0] = -400*x[0]*(x[1]-x[0]**2) - 2*(1-x[0])
grad[-1] = 200*(x[-1]-x[-2]**2)
for i in range(1, len(x)-1):
grad[i] = 200*(x[i]-x[i-1]**2) - 400*x[i]*(x[i+1]-x[i]**2) - 2*(1-x[i])
return grad
# Benchmark parameters
dim = 100 # Problem dimension
max_iter = 1000
gtol = 1e-6
x0 = np.zeros(dim) # Initial guess
def run_scipy_benchmark(method):
start = time.time()
result = minimize(rosenbrock, x0, method=method, jac=rosen_grad,
options={'maxiter': max_iter, 'gtol': gtol})
#print(result)
return {
'time': time.time() - start,
'nit': result.nit,
'nfev': result.njev,
'loss': result.fun,
'success': result.success
}
def run_optax_benchmark(max_iter,gtol,patience=1):
x = jnp.array(x0)
optimizer = optax.lbfgs(learning_rate=1.0)
value_and_grad = jax.value_and_grad(rosenbrock_jax)
@jax.jit
def step(params, opt_state):
value, grads = value_and_grad(params)
updates, opt_state = optimizer.update(grads, opt_state, params, value=value, grad=grads,value_fn=rosenbrock_jax)
return optax.apply_updates(params, updates), opt_state, value
opt_state = optimizer.init(x)
best_loss = float('inf')
start = time.time()
for nfev in range(max_iter):
x, opt_state, loss = step(x, opt_state)
if loss < best_loss - gtol:
best_loss = loss
no_improvement = 0
else:
no_improvement += 1
# Early stopping
if no_improvement >= patience:
break
return {
'time': time.time() - start,
'nit': nfev, # 1 evaluation per iteration
'loss': best_loss,
'success': best_loss < 1e-5
}
def run_opt( opt, max_iter, tol):
fun = rosenbrock_jax
init_params = jnp.zeros(dim)
value_and_grad_fun = optax.value_and_grad_from_state(fun)
def step(carry):
params, state = carry
value, grad = value_and_grad_fun(params, state=state)
updates, state = opt.update(
grad, state, params, value=value, grad=grad, value_fn=fun
)
params = optax.apply_updates(params, updates)
return params, state
def continuing_criterion(carry):
_, state = carry
iter_num = otu.tree_get(state, 'count')
grad = otu.tree_get(state, 'grad')
err = otu.tree_l2_norm(grad)
return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))
start = time.time()
init_carry = (init_params, opt.init(init_params))
final_params, final_state = jax.lax.while_loop(
continuing_criterion, step, init_carry
)
return {
'time': time.time() - start,
'nit': otu.tree_get(final_state, 'count'), # 1 evaluation per iteration
'loss':otu.tree_l2_norm(jax.grad(fun)(final_params)) }
print("JAX version ", jax.__version__)
print("Optax version ",optax.__version__)
# Run benchmarks
methods = ['L-BFGS-B']#, 'trust-ncg']
results = {}
for i in range(2):
for method in methods:
results[method] = run_scipy_benchmark(method)
results['Optax'] = run_optax_benchmark(max_iter=max_iter,gtol=gtol)
opt = optax.lbfgs()
results['Optax-help'] = run_opt(opt, max_iter=max_iter, tol=gtol)
print(f"{'Method':<12} | {'Time (s)':<8} | {' NIT':<6} | {'Loss':<12}")
print("------------------------------------------------")
for method, res in results.items():
print(f"{method:<12} | {res['time']:>8.3f} | {res['nit']:>6} | {res['loss']:.2e}")
print("\n")
In the "Optax" case you're measuring the compile time for step and the eager overhead since each iteration takes < 1 ms, but you have two eager ops:
x, opt_state, loss = step(x, opt_state) # jax op 1
if loss < best_loss - gtol: # jax op 2
In the "Optax-help" case you're measuring the compilation time for the while_loop, you can measure just the while_loop op by compiling it first:
while_loop = jax.jit(lambda carry: jax.lax.while_loop(continuing_criterion, step, carry))
_ = while_loop(init_carry)
start = time.time()
final_params, final_state = jax.block_until_ready(while_loop(init_carry))
After these changes (compiling step in "Optax" and compiling the while_loop in "Optax-help") I get:
Method | Time (s) | NIT | Loss
------------------------------------------------
L-BFGS-B | 0.170 | 497 | 6.84e-10
Optax | 0.275 | 495 | 3.88e-07
Optax-help | 0.009 | 516 | 6.62e-07
JAX has a (maybe not so well document) utility to sanity check if you're not actually measuring compilation: https://docs.jax.dev/en/latest/_autosummary/jax.log_compiles.html
Wrapping your timing step in this context manager showed me the while_loop is recompiled in the "Optax-help" the second time it's called (it's more obvious it's compiled the first time it's called).
start = time.time()
with jax.log_compiles():
final_params, final_state = jax.block_until_ready(jax.lax.while_loop(
continuing_criterion, step, init_carry
))
Incidentally, if I pass compiled JAX functions to scipy.minimize directly (importantly both the loss and the gradient) I get as low as:
L-BFGS-B | 0.068 | 501 | 4.59e-10
directly with scipy. It seems my machine is slower than yours, so maybe you can get scipy to run even faster with JAX compiled functions.
Btw, thanks for the full repro, that's very helpful!
Ok, thanks for the feedback. I updated the code with your suggestions. Indeed now the Optax-help is way faster. What I am surprised is the fact that before updating the code, I was running the function twice. So I would have expected to compile the code in the first iteration, and then on the second having only the run time. In my case providing the jax jitified function to scipy did not improve the time. (Method scipy+jax).
Here are the new times and the full code:
JAX version 0.5.2
Optax version 0.2.4
Method | Time (s) | NIT | Loss
------------------------------------------------
L-BFGS-B | 6.5e-02 | 498 | 3.64e-10
scipy+jax | 1.0e-01 | 495 | 5.83e-10
Optax | 6.3e-01 | 495 | 3.88e-07
Optax-help | 1.1e-03 | 516 | 6.62e-07
Method | Time (s) | NIT | Loss
------------------------------------------------
L-BFGS-B | 6.9e-02 | 498 | 3.64e-10
scipy+jax | 7.2e-02 | 495 | 5.83e-10
Optax | 5.6e-01 | 495 | 3.88e-07
Optax-help | 1.2e-03 | 516 | 6.62e-07
import numpy as np
import jax
import jax.numpy as jnp
import optax
import time
import optax
import optax.tree_utils as otu
from jax import value_and_grad
from scipy.optimize import minimize
# Configure JAX to use 64-bit precision
jax.config.update("jax_enable_x64", True)
# Rosenbrock function (standard test function)
def rosenbrock(x):
return np.sum(100.0*(x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2)
# JAX-compatible version
@jax.jit
def rosenbrock_jax(x):
return jnp.sum(100.0*(x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2)
# Gradient for SciPy
def rosen_grad(x):
grad = np.zeros_like(x)
grad[0] = -400*x[0]*(x[1]-x[0]**2) - 2*(1-x[0])
grad[-1] = 200*(x[-1]-x[-2]**2)
for i in range(1, len(x)-1):
grad[i] = 200*(x[i]-x[i-1]**2) - 400*x[i]*(x[i+1]-x[i]**2) - 2*(1-x[i])
return grad
# Benchmark parameters
dim = 100 # Problem dimension
max_iter = 1000
gtol = 1e-6
x0 = np.zeros(dim) # Initial guess
def run_scipy_benchmark(method):
start = time.time()
result = minimize(rosenbrock, x0, method=method, jac=rosen_grad,
options={'maxiter': max_iter, 'gtol': gtol})
#print(result)
return {
'time': time.time() - start,
'nit': result.nit,
'nfev': result.njev,
'loss': result.fun,
'success': result.success
}
def run_scipy_benchmark_with_jax(method):
value_and_grad = jax.jit(jax.value_and_grad(rosenbrock_jax))
start = time.time()
result = minimize(value_and_grad, x0, method=method, jac=True,
options={'maxiter': max_iter, 'gtol': gtol})
#print(result)
return {
'time': time.time() - start,
'nit': result.nit,
'nfev': result.njev,
'loss': result.fun,
'success': result.success
}
def run_optax_benchmark(max_iter,gtol,patience=1):
x = jnp.array(x0)
optimizer = optax.lbfgs(learning_rate=1.0)
value_and_grad = jax.value_and_grad(rosenbrock_jax)
@jax.jit
def step(params, opt_state):
value, grads = value_and_grad(params)
updates, opt_state = optimizer.update(grads, opt_state, params, value=value, grad=grads,value_fn=rosenbrock_jax)
return optax.apply_updates(params, updates), opt_state, value
opt_state = optimizer.init(x)
best_loss = float('inf')
start = time.time()
for nfev in range(max_iter):
x, opt_state, loss = step(x, opt_state)
if loss < best_loss - gtol:
best_loss = loss
no_improvement = 0
else:
no_improvement += 1
# Early stopping
if no_improvement >= patience:
break
return {
'time': time.time() - start,
'nit': nfev, # 1 evaluation per iteration
'loss': best_loss,
'success': best_loss < 1e-5
}
def run_opt( opt, max_iter, tol):
fun = rosenbrock_jax
init_params = jnp.zeros(dim)
value_and_grad_fun = optax.value_and_grad_from_state(fun)
def step(carry):
params, state = carry
value, grad = value_and_grad_fun(params, state=state)
updates, state = opt.update(
grad, state, params, value=value, grad=grad, value_fn=fun
)
params = optax.apply_updates(params, updates)
return params, state
def continuing_criterion(carry):
_, state = carry
iter_num = otu.tree_get(state, 'count')
grad = otu.tree_get(state, 'grad')
err = otu.tree_l2_norm(grad)
return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))
start = time.time()
init_carry = (init_params, opt.init(init_params))
#final_params, final_state = jax.lax.while_loop(
# continuing_criterion, step, init_carry
#)
while_loop = jax.jit(lambda carry: jax.lax.while_loop(continuing_criterion, step, carry))
_ = while_loop(init_carry)
start = time.time()
init_carry = (jnp.zeros(dim), opt.init(init_params))
final_params, final_state = while_loop(init_carry)
return {
'time': time.time() - start,
'nit': otu.tree_get(final_state, 'count'), # 1 evaluation per iteration
'loss':otu.tree_l2_norm(jax.grad(fun)(final_params)) }
print("JAX version ", jax.__version__)
print("Optax version ",optax.__version__)
# Run benchmarks
methods = ['L-BFGS-B']#, 'trust-ncg']
results = {}
for i in range(2):
for method in methods:
results[method] = run_scipy_benchmark(method)
results["scipy+jax"] = run_scipy_benchmark_with_jax(method)
results['Optax'] = run_optax_benchmark(max_iter=max_iter,gtol=gtol)
opt = optax.lbfgs()
results['Optax-help'] = run_opt(opt, max_iter=max_iter, tol=gtol)
print(f"{'Method':<12} | {'Time (s)':<8} | {' NIT':<6} | {'Loss':<12}")
print("------------------------------------------------")
for method, res in results.items():
print(f"{method:<12} | {res['time']:>8.1e} | {res['nit']:>6} | {res['loss']:.2e}")
print("\n")
Nice! In the jax.lax.while_loop case the closures you're creating aren't particularly compilation cache friendly unfortunately.
You can also run JAX with jax.config.update("jax_explain_cache_misses", True) which should catch the trouble with caching the while compilation.
In my case providing the jax jitified function to scipy did not improve the time. (Method scipy+jax).
I think I used a separate value and separate grad function and called them once each before to compile them if you want to give that a try.