
Extracting intermediate function values/losses from the solve

Open itk22 opened this issue 3 months ago • 4 comments

Dear optimistix team,

First of all, thank you for your effort in developing optimistix. I recently transitioned from JAXopt, and I love it!

I was wondering whether it is possible to extract the loss/function value history from an optimistix solve. In the code example below, it is easy to evaluate the intermediate losses with `multi_step_solve`, but this is much less efficient than the `single_step_solve` approach. Using `jax.lax.scan` would certainly improve performance over a Python `for` loop, but I was wondering if there is a simpler way to extract this information in optimistix.

```python
import jax.numpy as jnp
import optimistix as optx


def rastrigin(x, args):
    A = 10.0
    y = A * x.shape[0] + jnp.sum(x**2 - A * jnp.cos(2 * jnp.pi * x), axis=0)
    return y


# How can we extract the losses for a single_step_solve?
def single_step_solve(solver, y0):
    sol = optx.minimise(rastrigin, solver, max_steps=2_000, y0=y0, throw=False)
    return sol.value


def multi_step_solve(solver, y0):
    # This is much less efficient, but it's easy to extract losses.
    # Note: every optx.minimise call re-initialises the solver state.
    current_sol = y0
    losses = []
    for i in range(2_000):
        current_sol = optx.minimise(
            rastrigin, solver, max_steps=1, y0=current_sol, throw=False
        ).value
        losses.append(rastrigin(current_sol, None))
    return current_sol, losses
```

itk22, Mar 22 '24 13:03