optimistix
Extracting intermediate function values/losses from the solve
Dear optimistix team,
First of all, thank you for your effort in developing optimistix. I have recently transitioned from JAXOpt, and I love it!
I was wondering whether it is possible to extract the loss/function value history from an optimistix solve. In the code example below, it is easy to evaluate the intermediate losses with the multi_step_solve method, but it is much less efficient than the single_step_solve approach. Using jax.lax.scan would definitely improve performance over a Python for loop, but I was wondering if there is a simpler way to extract this information in optimistix.
import jax.numpy as jnp
import optimistix as optx

def rastrigin(x, args):
    A = 10.0
    y = A * x.shape[0] + jnp.sum(x**2 - A * jnp.cos(2 * jnp.pi * x), axis=0)
    return y

# How can we extract the losses for a single_step_solve?
def single_step_solve(solver, y0):
    sol = optx.minimise(rastrigin, solver, max_steps=2_000, y0=y0, throw=False)
    return sol.value

def multi_step_solve(solver, y0):
    # This is much less efficient, but it's easy to extract losses
    current_sol = y0
    for i in range(2_000):
        # Restart the solver from the previous iterate, one step at a time
        current_sol = optx.minimise(rastrigin, solver, max_steps=1, y0=current_sol, throw=False).value
    return current_sol
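For reference, here is a rough sketch of the jax.lax.scan idea mentioned above. To keep it self-contained it does not call optimistix at all: the per-step update is a hand-written gradient-descent step, used purely as a hypothetical stand-in for one solver step. The names scan_solve, num_steps and learning_rate are my own assumptions for illustration. The point is only the pattern: the iterate is the scan carry, and the loss is emitted as the per-step output.

```python
import jax
import jax.numpy as jnp


def rastrigin(x, args):
    A = 10.0
    return A * x.shape[0] + jnp.sum(x**2 - A * jnp.cos(2 * jnp.pi * x), axis=0)


def scan_solve(y0, num_steps=200, learning_rate=1e-3):
    # Hypothetical stand-in for one solver step: plain gradient descent,
    # NOT an optimistix solver.
    def body(y, _):
        loss, grad = jax.value_and_grad(rastrigin)(y, None)
        # The new iterate is the carry; the loss at the old iterate is stacked.
        return y - learning_rate * grad, loss

    y_final, losses = jax.lax.scan(body, y0, None, length=num_steps)
    return y_final, losses


y_final, losses = jax.jit(scan_solve)(jnp.array([0.3, -0.2]))
```

Here losses has one entry per iteration, so the whole history comes back from a single compiled call instead of 2_000 separate solves.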