Support Optax solvers that include a linesearch
This required passing the necessary keyword arguments through to the Optax solvers, and adding a new function so that value_fn works.
Currently this works with optax.scale_by_zoom_linesearch and with optax.lbfgs, which uses it, but it does not work with optax.scale_by_backtracking_linesearch. That needs a fix in Optax, which I think I have working; I will send them a PR for it.
I extended the list of minimizers in the tests. L-BFGS is also tested for least squares; explicitly chaining SGD with a line search is not. I also added SGD with backtracking line search, commented out for now. Once the fix lands on the Optax side, that test should pass as well.
Fixes #121
The tests pass locally with the latest development version of Optax, but they will fail with the current release, 0.2.4, which the workflow also uses.
Yes, the bug is already fixed on the Optax side; the fix will be in the next release.
The new optax release is here: https://github.com/google-deepmind/optax/releases/tag/v0.2.5. I rebased on dev, resolved the conflicts in tests/helpers, and enabled the commented-out test.
...and merged :)