optax
Add SPSA optimization method
The Simultaneous Perturbation Stochastic Approximation (SPSA) method is a gradient-free optimization method that can be much faster than finite differences.
If the number of parameters being optimized is p, the finite-difference method takes 2p measurements of the objective function at each iteration (to form one gradient approximation), while SPSA takes only two measurements.
It is also naturally suited to noisy measurements. Thus, it will be useful when simulating noisy systems.
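To make the two-measurement idea concrete, here is a minimal sketch of an SPSA gradient estimate in JAX (the name `spsa_grad` is hypothetical, not an existing optax function):

```python
import jax
import jax.numpy as jnp

def spsa_grad(f, x, key, eps=1e-3):
    """Hypothetical two-measurement SPSA gradient estimate (a sketch).

    All coordinates of x are perturbed simultaneously by a random
    Rademacher (+-1) vector, so only two evaluations of f are needed
    regardless of the dimension of x.
    """
    delta = jax.random.rademacher(key, x.shape).astype(x.dtype)
    # Central difference along the single random direction delta.
    df = f(x + eps * delta) - f(x - eps * delta)
    # Elementwise division by delta; for +-1 entries, 1/delta == delta.
    return df / (2 * eps) * delta
```

The estimate is noisy, but its expectation matches the true gradient, which is why averaging over iterations (as any stochastic optimizer does) still converges.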
The theory (and implementation) for SPSA is:
- Described in James C. Spall, “An Overview of the Simultaneous Perturbation Method for Efficient Optimization”, 1998;
Furthermore, it is implemented:
- In the `noisyopt` package; specifically, see the source code here.
- In Qiskit; see the SPSA optimizer documentation.
More information: https://www.jhuapl.edu/SPSA/
Sounds like a nice contribution, do you want to take a stab at it?
@mtthss if no one is working on it, I would like to try. Can you provide me with some general points before I start, like which files to update, what to take care of etc., as I haven't contributed to optax before.
Are there any updates on this? I have previously worked with SPSA in TF (https://github.com/tensorflow/quantum/pull/653) and would be interested in working on this, but I don't want to do redundant labor.
Hi @lockwo, are you still interested? If you can implement SPSA, it'll be of great help!
@ankit27kh : since there hasn't been activity for this in a year, I think it's safe for you to take over.
If you end up contributing this example, please add it to the contrib/ directory. Thanks!
@fabianp I've created the following implementation of a pseudo-gradient estimator:
https://gist.github.com/carlosgmartin/0ee29182a17b35baf7d402ebdc797486
As noted in the function's docstring:
- SPSA corresponds to the case where `sampler=random.rademacher`.
- Gaussian smoothing corresponds to the case where `sampler=random.normal`.
I'd be happy to contribute this implementation to Optax. I could put it under `optax.tree_utils` with the other tree functions, if desired.
I also welcome any feedback on the code.
Note that this pseudo-gradient can be used in combination with any existing Optax optimizer: its only role is to determine the gradient that is fed into the optimizer. Thus it acts as a replacement or analogue for `jax.grad(f)(x, key)`.
It would also be nice to have helper utility functions for estimating the gradient via forward and central finite differences:
https://gist.github.com/carlosgmartin/a147b43f39633dcb0a985b51a5b1af0c
I'd be happy to contribute these as well.
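For contrast with the two-measurement estimator, a central finite-difference gradient perturbs one coordinate at a time, costing 2p evaluations for p parameters. A minimal sketch (the name `central_diff_grad` is hypothetical, not the gist's exact code):

```python
import jax
import jax.numpy as jnp

def central_diff_grad(f, x, eps=1e-3):
    """Hypothetical central finite-difference gradient (a sketch).

    Uses 2p evaluations of f for a p-dimensional x, versus SPSA's two.
    """
    # One scaled basis vector per coordinate of x.
    basis = eps * jnp.eye(x.size, dtype=x.dtype)
    # vmap over the basis vectors: each row yields one partial derivative.
    return jax.vmap(lambda e: (f(x + e) - f(x - e)) / (2 * eps))(basis)
```

A forward-difference variant would use `(f(x + e) - f(x)) / eps`, halving the evaluations at the cost of O(eps) rather than O(eps^2) truncation error.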
Thanks for looking into this @carlosgmartin. @q-berthet is contributing a similar approach in #827 (which should be merged soon). The stochastic gradient estimator part of the codebase may also be relevant (https://optax.readthedocs.io/en/latest/api/stochastic_gradient_estimators.html). Unfortunately, the original authors of that part of the codebase have left, and given its poor support and adoption, we are not sure we will keep maintaining it. But you may find interesting links there. It would be great for the two of you, @q-berthet and @carlosgmartin, to discuss how to integrate these efforts.
About the forward/central difference schemes, similar discussions have happened in JAX (see e.g. https://github.com/jax-ml/jax/issues/15425), so other users have expressed similar needs. If some libraries already provide such tools, it may be better to use those rather than reinventing them (and maybe also check whether JAX ended up adding such a module).
Thanks again @carlosgmartin