
Add SPSA optimization method

Open ankit27kh opened this issue 2 years ago • 8 comments

The Simultaneous Perturbation Stochastic Approximation (SPSA) method is a fast, gradient-free optimization method that requires very few objective-function evaluations per iteration.

If the number of parameters being optimized is p, the finite-difference method takes 2p measurements of the objective function at each iteration (to form one gradient approximation), while SPSA takes only two measurements.

It is also naturally suited for noisy measurements. Thus, it will be useful when simulating noisy systems.

The theory for SPSA, along with pointers to existing implementations, is available here: https://www.jhuapl.edu/SPSA/
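The two-measurement idea above can be sketched in JAX as follows. This is only an illustrative sketch of the classic SPSA estimator (the function name `spsa_grad` and its signature are mine, not optax API):

```python
import jax
import jax.numpy as jnp

def spsa_grad(f, x, key, eps=1e-2):
    """Estimate grad f(x) from only two evaluations of f (SPSA sketch).

    Illustrative only: the Rademacher perturbation and central difference
    follow Spall's classic SPSA scheme, not any existing optax function.
    """
    # Random +/-1 perturbation for every coordinate simultaneously.
    delta = jax.random.rademacher(key, x.shape, dtype=x.dtype)
    # Two measurements of f, regardless of the dimension of x.
    df = f(x + eps * delta) - f(x - eps * delta)
    # Since delta_i is +/-1, dividing by delta_i equals multiplying by it.
    return df / (2 * eps) * delta

# Example: for f(x) = sum(x^2) the exact gradient is 2x.
f = lambda x: jnp.sum(x**2)
x = jnp.ones(5)
g = spsa_grad(f, x, jax.random.PRNGKey(0))
```

The estimate is noisy for a single key, but unbiased in expectation over the random perturbation direction.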

ankit27kh avatar Jun 09 '22 15:06 ankit27kh

Sounds like a nice contribution, do you want to take a stab at it?

mtthss avatar Jul 14 '22 09:07 mtthss

@mtthss if no one is working on it, I would like to try. Could you give me some general pointers before I start, such as which files to update and what to take care of, as I haven't contributed to optax before?

ankit27kh avatar Nov 18 '22 17:11 ankit27kh

Are there any updates on this? I have previously worked with SPSA in TF (https://github.com/tensorflow/quantum/pull/653) and would be interested in working on this, but I don't want to do redundant labor.

lockwo avatar Feb 01 '23 21:02 lockwo

Hi @lockwo, are you still interested? If you can implement SPSA, it'll be of great help!

ankit27kh avatar Feb 08 '24 06:02 ankit27kh

@ankit27kh : since there hasn't been activity for this in a year, I think it's safe for you to take over.

if you end up contributing this example, please do so to the contrib/ directory. Thanks!

fabianp avatar Feb 08 '24 07:02 fabianp

@fabianp I've created the following implementation of a pseudo-gradient estimator:

https://gist.github.com/carlosgmartin/0ee29182a17b35baf7d402ebdc797486

As noted in the function's docstring:

  • SPSA corresponds to the case where sampler=random.rademacher.
  • Gaussian smoothing corresponds to the case where sampler=random.normal.

I'd be happy to contribute this implementation to Optax. I could put it under optax.tree_utils with the other tree functions, if desired.

I also welcome any feedback on the code.

Note that this pseudo-gradient can be used in combination with any existing Optax optimizer: its only role is to determine the gradient that is fed into the optimizer. Thus it acts as a drop-in replacement or analogue for jax.grad(f)(x, key).

carlosgmartin avatar Sep 19 '24 00:09 carlosgmartin

It would also be nice to have helper utility functions for estimation of the gradient via forward and central finite differences:

https://gist.github.com/carlosgmartin/a147b43f39633dcb0a985b51a5b1af0c

I'd be happy to contribute these as well.
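For concreteness, the two schemes could look roughly like this. These are my own minimal sketches, not the gist's code (function names and signatures are illustrative):

```python
import jax.numpy as jnp

def forward_diff_grad(f, x, eps=1e-3):
    """Forward differences: n + 1 evaluations of f for n parameters,
    O(eps) truncation error. Illustrative sketch only."""
    fx = f(x)
    eye = jnp.eye(x.size, dtype=x.dtype)
    return jnp.array([(f(x + eps * e) - fx) / eps for e in eye])

def central_diff_grad(f, x, eps=1e-3):
    """Central differences: 2n evaluations, but O(eps^2) truncation
    error. Illustrative sketch only."""
    eye = jnp.eye(x.size, dtype=x.dtype)
    return jnp.array(
        [(f(x + eps * e) - f(x - eps * e)) / (2 * eps) for e in eye]
    )

# Example: the gradient of sum(sin(x)) is cos(x).
f = lambda x: jnp.sum(jnp.sin(x))
x = jnp.linspace(0.0, 1.0, 4)
g_true = jnp.cos(x)
```

The trade-off is the usual one: forward differences cost one fewer evaluation per coordinate, while central differences are an order more accurate in eps.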

carlosgmartin avatar Sep 19 '24 02:09 carlosgmartin

Thanks for looking into this @carlosgmartin. @q-berthet is contributing a similar approach in #827 (which should be merged soon). The stochastic gradient estimator part of the codebase may also be relevant (https://optax.readthedocs.io/en/latest/api/stochastic_gradient_estimators.html). Unfortunately, the original authors of that part of the codebase have left, and we are not sure we will keep maintaining it given its poor support and adoption, but you may find interesting links there.

It would be great for the two of you, @q-berthet and @carlosgmartin, to discuss how to integrate these efforts.

About the forward/central difference schemes: similar discussions have happened in JAX (see e.g. https://github.com/jax-ml/jax/issues/15425), and other users have expressed similar needs. If some libraries already provide such tools, it may be better to use those rather than reinvent them (it is also worth checking whether JAX ended up adding such a module).

Thanks again @carlosgmartin

vroulet avatar Sep 23 '24 17:09 vroulet