
Benchmark of lbfgs 10 times slower than scipy on CPU

Open · organic-chemistry opened this issue 8 months ago · 4 comments

I wanted to migrate to optax for a problem I was working on, and I noticed that lbfgs is much slower than scipy. I therefore implemented a small benchmark on the Rosenbrock function, and indeed it is about 10x slower. I either implemented the loop myself (method "Optax") or used the lbfgs example from the docs (method "Optax-help").

My naive implementation is about 10x slower in optax; using the one from the docs ("Optax-help") reduces that to 4-5x slower. My implementation performs the same number of iterations as the scipy one. The scipy method is L-BFGS-B, but if I understood correctly it should be the same method as L-BFGS, except that it also supports bounds. Here are the results (I ran them twice in a row):

JAX version  0.6.0
Optax version  0.2.5.dev
Method       | Time (s) |    NIT | Loss        
------------------------------------------------
L-BFGS-B     |    0.073 |    494 | 5.05e-10
Optax        |    0.616 |    495 | 3.88e-07
Optax-help   |    0.283 |    516 | 6.62e-07


Method       | Time (s) |    NIT | Loss        
------------------------------------------------
L-BFGS-B     |    0.081 |    494 | 5.05e-10
Optax        |    0.561 |    495 | 3.88e-07
Optax-help   |    0.283 |    516 | 6.62e-07

import numpy as np
import jax
import jax.numpy as jnp
import optax
from scipy.optimize import minimize
import time
import optax.tree_utils as otu

# Configure JAX to use 64-bit precision
jax.config.update("jax_enable_x64", True)

# Rosenbrock function (standard test function)
def rosenbrock(x):
    return np.sum(100.0*(x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2)

# JAX-compatible version
@jax.jit
def rosenbrock_jax(x):
    return jnp.sum(100.0*(x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2)

# Gradient for SciPy
def rosen_grad(x):
    grad = np.zeros_like(x)
    grad[0] = -400*x[0]*(x[1]-x[0]**2) - 2*(1-x[0])
    grad[-1] = 200*(x[-1]-x[-2]**2)
    for i in range(1, len(x)-1):
        grad[i] = 200*(x[i]-x[i-1]**2) - 400*x[i]*(x[i+1]-x[i]**2) - 2*(1-x[i])
    return grad

# Benchmark parameters
dim = 100  # Problem dimension
max_iter = 1000
gtol = 1e-6
x0 = np.zeros(dim)  # Initial guess


def run_scipy_benchmark(method):
    start = time.time()
    result = minimize(rosenbrock, x0, method=method, jac=rosen_grad,
                     options={'maxiter': max_iter, 'gtol': gtol})
    #print(result)
    return {
        'time': time.time() - start,
        'nit': result.nit,
        'nfev': result.njev,
        'loss': result.fun,
        'success': result.success
    }

def run_optax_benchmark(max_iter,gtol,patience=1):
    x = jnp.array(x0)
    optimizer = optax.lbfgs(learning_rate=1.0)
    value_and_grad = jax.value_and_grad(rosenbrock_jax)
    @jax.jit
    def step(params, opt_state):
        value, grads = value_and_grad(params)
        updates, opt_state = optimizer.update(grads, opt_state, params, value=value, grad=grads,value_fn=rosenbrock_jax)
        return optax.apply_updates(params, updates), opt_state, value
    
    opt_state = optimizer.init(x)
    best_loss = float('inf')
    no_improvement = 0
    start = time.time()
    
    for nfev in range(max_iter):
        x, opt_state, loss = step(x, opt_state)

        if loss < best_loss - gtol:
            best_loss = loss
            no_improvement = 0
        
        else:
            no_improvement += 1


        # Early stopping
        if no_improvement >= patience:
            break

    return {
        'time': time.time() - start,
        'nit': nfev,  # 1 evaluation per iteration
        'loss': best_loss,
        'success': best_loss < 1e-5
    }


def run_opt( opt, max_iter, tol):
    fun  = rosenbrock_jax
    init_params = jnp.zeros(dim)

    value_and_grad_fun = optax.value_and_grad_from_state(fun)

    def step(carry):
        params, state = carry
        value, grad = value_and_grad_fun(params, state=state)
        updates, state = opt.update(
            grad, state, params, value=value, grad=grad, value_fn=fun
        )
        params = optax.apply_updates(params, updates)
        return params, state

    def continuing_criterion(carry):
        _, state = carry
        iter_num = otu.tree_get(state, 'count')
        grad = otu.tree_get(state, 'grad')
        err = otu.tree_l2_norm(grad)
        return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))

    start = time.time()

    init_carry = (init_params, opt.init(init_params))
    final_params, final_state = jax.lax.while_loop(
        continuing_criterion, step, init_carry
    )
    return {
        'time': time.time() - start,
        'nit': otu.tree_get(final_state, 'count'),  # 1 evaluation per iteration
        'loss': otu.tree_l2_norm(jax.grad(fun)(final_params)),  # gradient norm at the solution, not the objective value
    }


print("JAX version ", jax.__version__)
print("Optax version ",optax.__version__)


# Run benchmarks
methods = ['L-BFGS-B']#, 'trust-ncg']
results = {}
for i in range(2):
    for method in methods:
        results[method] = run_scipy_benchmark(method)
        
    results['Optax'] = run_optax_benchmark(max_iter=max_iter,gtol=gtol)
    opt = optax.lbfgs()

    results['Optax-help'] = run_opt(opt, max_iter=max_iter, tol=gtol)


    print(f"{'Method':<12} | {'Time (s)':<8} | {'   NIT':<6} | {'Loss':<12}")
    print("------------------------------------------------")
    for method, res in results.items():
        print(f"{method:<12} | {res['time']:>8.3f} | {res['nit']:>6} | {res['loss']:.2e}")
    print("\n")

organic-chemistry · May 12 '25 08:05

In the "Optax" case you're measuring the compile time for step and the eager overhead since each iteration takes < 1 ms, but you have two eager ops:

        x, opt_state, loss = step(x, opt_state)  # jax op 1

        if loss < best_loss - gtol:                        # jax op 2
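
For example, the "Optax" loop could be timed along these lines (a sketch, assuming the step, x, opt_state, max_iter, gtol and patience definitions from your benchmark): warm step up once so compilation happens before the timer starts, and make the host transfer explicit:

    # warm-up call: triggers jit compilation; the results are discarded
    _x, _s, _ = step(x, opt_state)
    jax.block_until_ready(_x)

    best_loss, no_improvement = float('inf'), 0
    start = time.time()
    for nfev in range(max_iter):
        x, opt_state, loss = step(x, opt_state)
        loss = float(loss)              # explicit device-to-host transfer each iteration
        if loss < best_loss - gtol:
            best_loss, no_improvement = loss, 0
        else:
            no_improvement += 1
        if no_improvement >= patience:  # early stopping
            break
    elapsed = time.time() - start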

In the "Optax-help" case you're measuring the compilation time for the while_loop, you can measure just the while_loop op by compiling it first:

    while_loop = jax.jit(lambda carry: jax.lax.while_loop(continuing_criterion, step, carry))
    _ = while_loop(init_carry)
    start = time.time()
    final_params, final_state = jax.block_until_ready(while_loop(init_carry))

After these changes (compiling step in "Optax" and compiling the while_loop in "Optax-help") I get:

Method       | Time (s) |    NIT | Loss        
------------------------------------------------
L-BFGS-B     |    0.170 |    497 | 6.84e-10
Optax        |    0.275 |    495 | 3.88e-07
Optax-help   |    0.009 |    516 | 6.62e-07

JAX has a (maybe not so well documented) utility to sanity-check whether you're accidentally measuring compilation: https://docs.jax.dev/en/latest/_autosummary/jax.log_compiles.html

Wrapping your timed section in this context manager showed me that the while_loop is recompiled in the "Optax-help" case the second time it's called (it's more obvious that it's compiled the first time it's called).

    start = time.time()
    with jax.log_compiles():
      final_params, final_state = jax.block_until_ready(jax.lax.while_loop(
          continuing_criterion, step, init_carry
      ))

rdyro · May 12 '25 17:05

Incidentally, if I pass compiled JAX functions to scipy.minimize directly (importantly both the loss and the gradient) I get as low as:

L-BFGS-B     |    0.068 |    501 | 4.59e-10

directly with scipy. It seems my machine is slower than yours, so you may be able to get scipy to run even faster with JAX-compiled functions.

Btw, thanks for the full repro, that's very helpful!

rdyro · May 12 '25 18:05

Ok, thanks for the feedback. I updated the code with your suggestions, and indeed the Optax-help version is now way faster. What surprises me is that, before updating the code, I was already running the function twice, so I would have expected the first run to compile the code and the second to measure only the run time. In my case, providing the jitted JAX function to scipy did not improve the time (method "scipy+jax").

Here are the new times and the full code:

JAX version  0.5.2
Optax version  0.2.4
Method       | Time (s) |    NIT | Loss        
------------------------------------------------
L-BFGS-B     |  6.5e-02 |    498 | 3.64e-10
scipy+jax    |  1.0e-01 |    495 | 5.83e-10
Optax        |  6.3e-01 |    495 | 3.88e-07
Optax-help   |  1.1e-03 |    516 | 6.62e-07


Method       | Time (s) |    NIT | Loss        
------------------------------------------------
L-BFGS-B     |  6.9e-02 |    498 | 3.64e-10
scipy+jax    |  7.2e-02 |    495 | 5.83e-10
Optax        |  5.6e-01 |    495 | 3.88e-07
Optax-help   |  1.2e-03 |    516 | 6.62e-07


import numpy as np
import jax
import jax.numpy as jnp
import optax
import time
import optax.tree_utils as otu
from jax import value_and_grad

from scipy.optimize import minimize


# Configure JAX to use 64-bit precision
jax.config.update("jax_enable_x64", True)

# Rosenbrock function (standard test function)
def rosenbrock(x):
    return np.sum(100.0*(x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2)

# JAX-compatible version
@jax.jit
def rosenbrock_jax(x):
    return jnp.sum(100.0*(x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2)

# Gradient for SciPy
def rosen_grad(x):
    grad = np.zeros_like(x)
    grad[0] = -400*x[0]*(x[1]-x[0]**2) - 2*(1-x[0])
    grad[-1] = 200*(x[-1]-x[-2]**2)
    for i in range(1, len(x)-1):
        grad[i] = 200*(x[i]-x[i-1]**2) - 400*x[i]*(x[i+1]-x[i]**2) - 2*(1-x[i])
    return grad

# Benchmark parameters
dim = 100  # Problem dimension
max_iter = 1000
gtol = 1e-6
x0 = np.zeros(dim)  # Initial guess


def run_scipy_benchmark(method):
    start = time.time()
    result = minimize(rosenbrock, x0, method=method, jac=rosen_grad,
                     options={'maxiter': max_iter, 'gtol': gtol})
    #print(result)
    return {
        'time': time.time() - start,
        'nit': result.nit,
        'nfev': result.njev,
        'loss': result.fun,
        'success': result.success
    }

def run_scipy_benchmark_with_jax(method):
    value_and_grad = jax.jit(jax.value_and_grad(rosenbrock_jax))
    start = time.time()

    result = minimize(value_and_grad, x0, method=method, jac=True,
                     options={'maxiter': max_iter, 'gtol': gtol})
    #print(result)
    return {
        'time': time.time() - start,
        'nit': result.nit,
        'nfev': result.njev,
        'loss': result.fun,
        'success': result.success
    }

def run_optax_benchmark(max_iter,gtol,patience=1):
    x = jnp.array(x0)
    optimizer = optax.lbfgs(learning_rate=1.0)
    value_and_grad = jax.value_and_grad(rosenbrock_jax)
    @jax.jit
    def step(params, opt_state):
        value, grads = value_and_grad(params)
        updates, opt_state = optimizer.update(grads, opt_state, params, value=value, grad=grads,value_fn=rosenbrock_jax)
        return optax.apply_updates(params, updates), opt_state, value
    
    opt_state = optimizer.init(x)
    best_loss = float('inf')
    no_improvement = 0
    start = time.time()
    
    for nfev in range(max_iter):
        x, opt_state, loss = step(x, opt_state)

        if loss < best_loss - gtol:
            best_loss = loss
            no_improvement = 0
        
        else:
            no_improvement += 1


        # Early stopping
        if no_improvement >= patience:
            break

    return {
        'time': time.time() - start,
        'nit': nfev,  # 1 evaluation per iteration
        'loss': best_loss,
        'success': best_loss < 1e-5
    }


def run_opt( opt, max_iter, tol):
    fun  = rosenbrock_jax
    init_params = jnp.zeros(dim)

    value_and_grad_fun = optax.value_and_grad_from_state(fun)

    def step(carry):
        params, state = carry
        value, grad = value_and_grad_fun(params, state=state)
        updates, state = opt.update(
            grad, state, params, value=value, grad=grad, value_fn=fun
        )
        params = optax.apply_updates(params, updates)
        return params, state

    def continuing_criterion(carry):
        _, state = carry
        iter_num = otu.tree_get(state, 'count')
        grad = otu.tree_get(state, 'grad')
        err = otu.tree_l2_norm(grad)
        return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))


    init_carry = (init_params, opt.init(init_params))
    #final_params, final_state = jax.lax.while_loop(
    #    continuing_criterion, step, init_carry
    #)
    while_loop = jax.jit(lambda carry: jax.lax.while_loop(continuing_criterion, step, carry))
    _ = while_loop(init_carry)

    start = time.time()
    init_carry = (jnp.zeros(dim), opt.init(init_params))

    final_params, final_state = while_loop(init_carry)  # note: unlike the suggested snippet, no jax.block_until_ready here
    return {
        'time': time.time() - start,
        'nit': otu.tree_get(final_state, 'count'),  # 1 evaluation per iteration
        'loss': otu.tree_l2_norm(jax.grad(fun)(final_params)),  # gradient norm at the solution, not the objective value
    }


print("JAX version ", jax.__version__)
print("Optax version ",optax.__version__)


# Run benchmarks
methods = ['L-BFGS-B']#, 'trust-ncg']
results = {}
for i in range(2):
    for method in methods:
        results[method] = run_scipy_benchmark(method)

    results["scipy+jax"] = run_scipy_benchmark_with_jax(method)
        
    results['Optax'] = run_optax_benchmark(max_iter=max_iter,gtol=gtol)
    opt = optax.lbfgs()

    results['Optax-help'] = run_opt(opt, max_iter=max_iter, tol=gtol)


    print(f"{'Method':<12} | {'Time (s)':<8} | {'   NIT':<6} | {'Loss':<12}")
    print("------------------------------------------------")
    for method, res in results.items():
        print(f"{method:<12} | {res['time']:>8.1e} | {res['nit']:>6} | {res['loss']:.2e}")
    print("\n")

organic-chemistry · May 13 '25 08:05

Nice! In the jax.lax.while_loop case, the closures you're creating aren't particularly compilation-cache friendly, unfortunately.

You can also run JAX with jax.config.update("jax_explain_cache_misses", True), which should catch the trouble with caching the while_loop compilation.
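
For illustration, one cache-friendly arrangement could look like this (a sketch, assuming the rosenbrock_jax, dim, max_iter and gtol definitions from your benchmark): create opt and the jitted solver once at module level, so repeated benchmark runs call the same function object and hit jit's compilation cache instead of re-tracing a freshly created lambda each time.

import jax
import jax.numpy as jnp
import optax
import optax.tree_utils as otu

fun = rosenbrock_jax                 # assumed to be defined as in the benchmark above
opt = optax.lbfgs()                  # created once, reused across runs
value_and_grad_fun = optax.value_and_grad_from_state(fun)

def _step(carry):
    params, state = carry
    value, grad = value_and_grad_fun(params, state=state)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=fun
    )
    return optax.apply_updates(params, updates), state

def _continuing(carry):
    _, state = carry
    iter_num = otu.tree_get(state, 'count')
    err = otu.tree_l2_norm(otu.tree_get(state, 'grad'))
    return (iter_num == 0) | ((iter_num < max_iter) & (err >= gtol))

@jax.jit  # same function object on every call, so later runs reuse the first compilation
def run_lbfgs(init_params):
    return jax.lax.while_loop(_continuing, _step, (init_params, opt.init(init_params)))

Each benchmark run then just calls jax.block_until_ready(run_lbfgs(jnp.zeros(dim))) (with one warm-up call before timing), and jax.log_compiles() should stay quiet on the second run.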

In my case, providing the jitted JAX function to scipy did not improve the time (method "scipy+jax").

I think I used a separate value function and a separate grad function, and called each of them once beforehand to compile them, if you want to give that a try.
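
Roughly like this (a sketch, not the exact code I ran; it assumes the rosenbrock_jax, x0, max_iter and gtol definitions from your benchmark):

import numpy as np
import time
import jax
from scipy.optimize import minimize

loss_fn = jax.jit(rosenbrock_jax)            # compiled objective
grad_fn = jax.jit(jax.grad(rosenbrock_jax))  # compiled gradient

# call each once so compilation happens outside the timed region
_ = loss_fn(x0)
_ = grad_fn(x0)

start = time.time()
result = minimize(lambda x: float(loss_fn(x)), x0, method='L-BFGS-B',
                  jac=lambda x: np.asarray(grad_fn(x)),
                  options={'maxiter': max_iter, 'gtol': gtol})
elapsed = time.time() - start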

rdyro · May 13 '25 16:05