Running multiple optimizations at once
Hey,
I was wondering how one would use optimistix to run multiple optimizations from different initial guesses. Presumably one would use jax.vmap in an interactive solve, right? Alternatively, I was thinking that one could just minimize the sum of the errors from each individual optimization. But maybe this is an issue for higher-order methods, since one has a lot of independent variables in this case.
You can vmap all Optimistix solves (minimisation, least-squares, root-finds...); this works out of the box without modifications. You can also vmap over an appropriately defined interactive solve.
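A minimal sketch of what this can look like, using BFGS and returning just the optimised value so that the output is a plain array (whether the full Solution object survives jax.vmap depends on the solver, see below):
import jax
import jax.numpy as jnp
import optimistix as optx

def fn(y, args):
    # simple quadratic objective with its minimum at the origin
    return jnp.sum(y**2)

solver = optx.BFGS(rtol=1e-8, atol=1e-8)

def solve_one(y0):
    return optx.minimise(fn, solver, y0).value

y0s = jnp.linspace(-1.0, 1.0, 8).reshape(4, 2)  # four initial guesses in R^2
values = jax.vmap(solve_one)(y0s)  # shape (4, 2): one optimum per initial guess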
By the way, I remember you were curious about L-BFGS - we now have something you could try in https://github.com/patrick-kidger/optimistix/pull/135.
Then I must be doing something wrong, because it doesn't seem to work with least-squares solvers - I get a TypeError from JAX.
I tried:
import jax
import jax.numpy as jnp
import numpy as np
import optimistix

def test_func(x, args):
    return jnp.sum(x**2)

y0 = jnp.array(np.random.uniform(-100, 100, size=(10, 10, 10)))
solver = optimistix.GaussNewton(rtol=1e-12, atol=1e-12)
solution = jax.vmap(optimistix.minimise, in_axes=(None, None, 0))(test_func, solver, y0)
And I also tried it using:
def run_minimise(y0):
    solver = optimistix.GaussNewton(rtol=1e-12, atol=1e-12)
    solution = optimistix.minimise(test_func, solver, y0)
    return solution

solution = jax.vmap(run_minimise, in_axes=0)(y0)
TypeError: Output from batched function { lambda a:f32[10,10]; b:f32[10,10]. let c:f32[10,10] = mul b a d:f32[] = reduce_sum[axes=(0, 1)] c in (d,) } with type <class 'jax._src.core.Jaxpr'> is not a valid JAX type
Similar issues arise when I try this with an interactive solve, where I vmap over solver.init() and solver.step().
I remember that (e.g. when using jax.lax.scan()) one can solve this by using equinox.partition(), but I don't see how a user could apply that here.
Ah yes. I'll definitely have a look at L-BFGS :)
I'm at the lake, so just a quick fix for now: use Equinox, specifically eqx.filter_vmap. This will work - I use it all the time.
jax.vmap requires that all of the inputs are PyTrees of arrays, which the solver is not: it is a class that has methods, and these are not arrays. For more on the topic, you can take a look here.
The issue with the second approach you took is that JAX does not know how to assemble a batched Solution. In the case of GaussNewton, this contains a jaxpr (the error you are seeing refers to that). You can avoid this error by returning just the value, but you might want to get the additional information! In that case, eqx.filter_vmap has you covered.
The snippet below illustrates this - it will fail if you set RETURN_VALUE to False.
import jax
import jax.numpy as jnp
import optimistix as optx

RETURN_VALUE = True

solver = optx.GaussNewton(rtol=1e-3, atol=1e-6)

def solve(y0):
    if RETURN_VALUE:
        return optx.least_squares(lambda y, args: jnp.sum(y**2), solver, y0).value
    else:
        # Return the full Solution object instead, which contains a jaxpr
        return optx.least_squares(lambda y, args: jnp.sum(y**2), solver, y0)

y0s = jnp.arange(10.0)
jax.vmap(solve)(y0s)
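For completeness, a sketch of the eqx.filter_vmap route mentioned above, which lets you keep the full Solution object (assuming the same solve function with RETURN_VALUE = False):
import equinox as eqx

# Batches the array leaves of the returned Solution and leaves the
# non-array leaves (such as the jaxpr) unbatched.
solutions = eqx.filter_vmap(solve)(y0s)
print(solutions.value)   # one optimised value per initial guess
print(solutions.result)  # one RESULTS entry per initial guess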
Side note @johannahaffner I think we could fix this to work with jax.vmap by wrapping the jaxpr in an eqxi.Static. It's probably worth being compatible with jax.vmap by default; I try to encourage the feeling that the Equinox ecosystem is a library rather than a framework.
Yes, good point. I'll fix it!
Would it be preferable for solution.result to be an array of RESULTS or to provide some kind of reduction utility? (I'm having to write a batched nonlinear solve for an example ADI time stepper I'm working on.)
Solution.result already supports batching; it is an array with a batch axis. The thing we could improve is solely due to a jaxpr being present in a non-batched pytree leaf for some solvers.
For instance, if you want to access the results of individual solves, you can write
solver_results = solution.result == optx.RESULTS.successful
anywhere in your code. This will give you an array of booleans that you can use to filter for successful solves, for instance. (You can of course also use this trick to convert any other result message into a bool.)
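As a small follow-on example (mine, not from the original comment), the resulting mask can be used to pick out the converged solves from a batched solution - boolean indexing like this works outside of jit:
import jax.numpy as jnp

successful = solution.result == optx.RESULTS.successful
converged_values = solution.value[successful]  # optimised values from solves that converged
num_converged = jnp.sum(successful)  # how many of the batched solves succeeded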
Lurking here since I'm hoping to test the optimistix version of L-BFGS-B if/when it materialises (as the jaxopt one has a few problems!), but I have a comment relevant to this.
A recent academic work, https://github.com/Fra0013To/AD_MultiStartOpt/tree/main (and the corresponding paper https://www.mdpi.com/2227-7390/12/8/1201 ), demonstrates that you can treat your M starting points as a "single" point handled by a single optimizer (avoiding the overhead of running multiple optimizers with their associated intermediates, meaning you can increase the batch size somewhat), as long as your objective function and return value(s) are handled correctly. This works because of how modern autodiff-capable libraries trace the compute graph: you still get correct gradients and descent directions for each of the "sub-points".
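As a minimal check of that gradient claim (my own illustration, not taken from the paper): because the rows of the batch never interact, the gradient of the summed objective with respect to each starting point equals the gradient of the single-point objective at that point.
import jax
import jax.numpy as jnp

def f(x):
    # objective for a single starting point
    return jnp.sum(x**2)

def multi_start_objective(xs):
    # treat the leading axis as the starting-point axis and sum the per-point objectives
    return jnp.sum(jax.vmap(f)(xs))

xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
grads_combined = jax.grad(multi_start_objective)(xs)  # gradient of the summed objective
grads_individual = jax.vmap(jax.grad(f))(xs)          # per-point gradients, computed separately
assert jnp.allclose(grads_combined, grads_individual)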
Having done this myself for another problem, there are a few trade-offs. The larger the combined batch grows, the more likely it is that one of your starting points runs to the iteration limit, making all of the already-converged points wait - although that is perhaps not terrible compared to vmap (where you also have to wait). You also need to decide how to aggregate the objective values of your sub-points, since that can affect downstream decisions such as step sizes. If most points have converged, then taking the average of the objectives will cause the remaining ones to stop early (because the iteration-to-iteration tolerances are reached), whereas taking, say, the sum gives you the former behaviour (running to the iteration limit even if most points have already converged).
You can achieve this either by having the vmap inside your objective, if that makes sense, or by just performing your operations over the leading dimension where possible - though as I understand it the paper strictly calls for the latter:
def objective_vector(xs, args):
    # operate directly over the leading (starting point) dimension
    res = jnp.sum(xs**2, axis=-1)
    return jnp.sum(res), res  # aggregate with jnp.sum or jnp.mean, see the trade-off above

def objective_vmap(xs, args):
    def single_obj(x):
        return jnp.sum(x**2)
    res = jax.vmap(single_obj)(xs)
    return jnp.sum(res), res  # aggregate with jnp.sum or jnp.mean, see the trade-off above

solver = optx.SomeSolver()  # placeholder for any Optimistix minimiser
xs = m_random_guesses()  # placeholder; xs has shape (m, x0)
sol = optx.minimise(objective_vector, solver, xs, has_aux=True)  # sol.aux holds the individual evaluations for each x0
It would actually be useful to me if this type of functionality could be built into the solvers - with additional heuristics for early stopping based on the aux data (e.g. stop the whole batch if there is no change in the best point after p iterations, or if any point is within some optimality gap of a provided lower bound, etc.) - perhaps by allowing a Callable for tol=?
since I'm hoping to test the optimistix version of L-BFGS-B if/when it materialises
So I have a working version of BFGS-B, not yet tested with our new L-BFGS Hessian update. I can open a draft PR and mention you, would be nice if you'd like to road-test it on real problems!
Regarding support in Optimistix for objective functions that implement an internal vmap - a user-defined callable argument for stopping criteria is something that has been floated to us, but we currently do not have immediate plans to add it. If we were to add it, then users could implement their own custom termination criteria and add heuristics that fit the problem and the solver.
FWIW I would be surprised if moving the vmap into the objective function itself actually eked out a performance gain in JAX, since vectorisation is one thing it does really well.
So I have a working version of BFGS-B, not yet tested with our new L-BFGS Hessian update. I can open a draft PR and mention you, would be nice if you'd like to road-test it on real problems!
Absolutely, I can give it a whirl!
Regarding support in Optimistix for objective functions that implement an internal vmap - a user-defined callable argument for stopping criteria is something that has been floated to us, but we currently do not have immediate plans to add it. If we were to add it, then users could implement their own custom termination criteria and add heuristics that fit the problem and the solver.
FWIW I would be surprised if moving the vmap into the objective function itself actually eked out a performance gain in JAX, since vectorisation is one thing it does really well.
I can believe this about moving vmap into the objective function (as far as compute time goes). Moving the vmap into the objective does for now fix the initial problem, in that there is no longer an issue re: a "batched solution".
Having said this, I thought I did manage to recover some memory by (re)moving it - I wonder how the positioning of the vmap, or not using it at all and being explicit about the vectorisation, actually affects the compute graph if the whole thing is wrapped in a jax.jit() up the top???
Absolutely, I can give it a whirl!
Great!
Having said this, I thought I did manage to recover some memory by (re)moving it - I wonder how the positioning of the vmap, or not using it at all and being explicit about the vectorisation, actually affects the compute graph if the whole thing is wrapped in a jax.jit() up the top???
So purely in terms of runtime, without jit in the picture, they are exactly the same:
import jax
import jax.numpy as jnp

def rosenbrock(y, constant_factor):
    x1, x2 = y
    return (constant_factor - x1) ** 2 + 100 * (x2 - x1**2) ** 2

def vmapped_rosenbrock(ys, constant_factor):
    return jax.vmap(rosenbrock, in_axes=(0, None))(ys, constant_factor)

constant_factor = 1
exponents = jnp.arange(1, 8)

for e in exponents:
    batch_dimension = 10**e
    many_y0s = jnp.zeros((batch_dimension, 2))
    print(f"testing for batch dimension {batch_dimension}")
    %timeit jax.vmap(rosenbrock, in_axes=(0, None))(many_y0s, constant_factor)
    %timeit vmapped_rosenbrock(many_y0s, constant_factor)
    print("")
# testing for batch dimension 10
# 923 μs ± 12.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 915 μs ± 9.51 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# testing for batch dimension 100
# 916 μs ± 4.69 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 910 μs ± 2.31 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# testing for batch dimension 1000
# 967 μs ± 3.33 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 970 μs ± 3.55 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# testing for batch dimension 10000
# 974 μs ± 4.63 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 988 μs ± 20.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# testing for batch dimension 100000
# 1.15 ms ± 13 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 1.14 ms ± 17.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# testing for batch dimension 1000000
# 2.54 ms ± 26.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 2.54 ms ± 15.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# testing for batch dimension 10000000
# 23.4 ms ± 389 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 22.9 ms ± 89.5 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
When jitting the whole thing, the picture does not change - we're just way faster:
# testing for batch dimension 10
# 4.7 μs ± 34.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 4.67 μs ± 27.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# testing for batch dimension 100
# 4.77 μs ± 37.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 4.8 μs ± 96.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# testing for batch dimension 1000
# 6.86 μs ± 55.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 6.85 μs ± 42.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# testing for batch dimension 10000
# 7.48 μs ± 93.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 7.41 μs ± 26.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# testing for batch dimension 100000
# 16.5 μs ± 998 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 16.4 μs ± 179 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# testing for batch dimension 1000000
# 209 μs ± 7.47 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 200 μs ± 6.01 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# testing for batch dimension 10000000
# 2.61 ms ± 51.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 2.7 ms ± 63.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
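(For reference, a sketch of how the jitted variants might be set up - the function names here are illustrative, not the exact code used for the timings above:)
jitted_outer_vmap = jax.jit(jax.vmap(rosenbrock, in_axes=(0, None)))
jitted_inner_vmap = jax.jit(vmapped_rosenbrock)

# After a warm-up call to trigger compilation, e.g. in IPython:
# %timeit jitted_outer_vmap(many_y0s, constant_factor).block_until_ready()
# %timeit jitted_inner_vmap(many_y0s, constant_factor).block_until_ready()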
I think that makes sense, you're not drastically changing the computation graph - just moving the vmap one level inwards.
In particular, jax.vmap is written in such a way that the individual operations (for each element in y0 in this example) do not affect each other, but if you're vmapping inside the objective that now becomes the user's job, as you have pointed out.
I agree that it would be nice if regular JAX transformations worked out of the box here, though! For now I recommend eqx.filter_{vmap, jit} as a workaround, but I have started working on a fix :) This only affects solvers that contain a Jacobian operator that is not materialised (and hence contains a jaxpr); in particular this affects the GaussNewton class. It already seems to work out of the box for the minimisers.
Absolutely, I can give it a whirl!
Great!
I'll keep an eye out for it :)
I think that makes sense, you're not drastically changing the computation graph - just moving the vmap one level inwards.
In particular, jax.vmap is written in such a way that the individual operations (for each element in y0 in this example) do not affect each other, but if you're vmapping inside the objective that now becomes the user's job, as you have pointed out.
It seems that commenting here was ultimately the right call, as you've clarified for me something about JAX that I hadn't fully internalised correctly, so cheers for that.
@cjchristopher I do have a version that requires a linear solve (and works well for BFGS-B), but upon re-reading the literature for the limited-memory variant, I don't think that my current version would scale so well. The original published version manages to only invert a linear operator that is square in the history length.
We do have all the building blocks: a direct method to construct the limited-memory Hessian operator that is needed, code that identifies the generalised Cauchy point, and functions that truncate the step length to the maximum feasible step length. What is missing is solving for the projected step using the ingredients that also go into the generation of our inverse Hessian operator (certain products of elements of the gradient and step histories).
If you have a pressing need for this, I'd be happy to offer guidance on implementing what is missing. Otherwise I'll probably get around to it in a couple of weeks :)