Implement fixed point routine jax.scipy.optimize.fixed_point
This PR adds a general implementation of jax.scipy.optimize.fixed_point with an API similar to the one in the original SciPy package. The implementation defines a custom VJP derived from the implicit function theorem.
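For readers unfamiliar with the technique, here is a minimal sketch of a fixed-point solver whose VJP comes from the implicit function theorem. It is not the code from this PR: the names `fixed_point`, `rev_iter`, `newton_sqrt_step`, and the hard-coded tolerance `TOL` are illustrative assumptions, and the solver is plain iteration on scalars. If x* = f(a, x*), then for an incoming cotangent v the quantity w = v + (∂f/∂x)ᵀ w is itself a fixed point, and the gradient with respect to a is (∂f/∂a)ᵀ w.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

TOL = 1e-6  # illustrative convergence tolerance (assumption)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
    """Iterate x <- f(a, x) until successive iterates differ by less than TOL."""

    def cond_fun(carry):
        x_prev, x = carry
        return jnp.abs(x_prev - x) > TOL

    def body_fun(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = lax.while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
    return x_star


def fixed_point_fwd(f, a, x_guess):
    x_star = fixed_point(f, a, x_guess)
    # Only the parameters and the solution are saved, not the iterates.
    return x_star, (a, x_star)


def rev_iter(f, packed, w):
    # One step of the adjoint iteration w <- v + (df/dx)^T w.
    a, x_star, v = packed
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    return v + vjp_x(w)[0]


def fixed_point_rev(f, res, v):
    a, x_star = res
    # Solve the adjoint fixed-point problem, then pull back through a.
    w = fixed_point(partial(rev_iter, f), (a, x_star, v), v)
    _, vjp_a = jax.vjp(lambda a_: f(a_, x_star), a)
    (a_bar,) = vjp_a(w)
    # The solution does not depend on the initial guess.
    return a_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)


def newton_sqrt_step(a, x):
    # Newton update whose fixed point is sqrt(a).
    return 0.5 * (x + a / x)


root = fixed_point(newton_sqrt_step, 2.0, 2.0)
grad = jax.grad(lambda a: fixed_point(newton_sqrt_step, a, a))(2.0)
```

In this sketch the backward pass never differentiates through the `while_loop`; it only needs VJPs of `f` evaluated at the solution, which is the main appeal of the implicit approach.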
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
@Maintainers: I do not see why the doc generation is failing. I cannot find any error in the log that is related to this PR.
I think this would be a good fit for JAXopt. Both root finding and fixed point resolution are in the scope of the library (see documentation). There would be a bit of work to stick to our API but it would be worth it (the optimization loop and implicit diff would be taken care of).
Minor detail: Although root finding and fixed point resolution are equivalent, Steffensen's method seems to be described as a root-finding method in textbooks. So maybe it's better to document it that way.
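To make the equivalence concrete, here is a minimal sketch of Steffensen's method written as a root finder for g(x) = 0, with the fixed-point variant obtained by setting g(x) = f(x) - x. The function names, default tolerance, and iteration cap are assumptions for illustration and are not taken from the PR.

```python
import jax.numpy as jnp
from jax import lax


def steffensen_root(g, x0, tol=1e-6, maxiter=50):
    """Find a scalar root of g with x <- x - g(x)^2 / (g(x + g(x)) - g(x))."""

    def cond(state):
        i, x = state
        return (jnp.abs(g(x)) > tol) & (i < maxiter)

    def body(state):
        i, x = state
        gx = g(x)
        denom = g(x + gx) - gx
        # Guard against a zero denominator; take no step in that case.
        safe = jnp.where(denom == 0, 1.0, denom)
        step = jnp.where(denom == 0, 0.0, gx * gx / safe)
        return i + 1, x - step

    x_init = jnp.asarray(x0, dtype=jnp.float32)
    _, x = lax.while_loop(cond, body, (0, x_init))
    return x


def steffensen_fixed_point(f, x0, **kwargs):
    # A fixed point of f is a root of g(x) = f(x) - x.
    return steffensen_root(lambda x: f(x) - x, x0, **kwargs)


x_cos = steffensen_fixed_point(jnp.cos, 1.0)          # solves cos(x) = x
x_sqrt2 = steffensen_root(lambda x: x * x - 2.0, 1.5)  # solves x^2 = 2
```

With g(x) = f(x) - x the denominator becomes f(f(x)) - 2 f(x) + x, so the same update can be read as Aitken's delta-squared acceleration of the iteration x <- f(x).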
@JanLuca – what do you think of contributing this to JAXopt instead, as @mblondel suggests?
Generally speaking, when a component of jax.scipy.optimize involves a substantial implementation that would significantly increase the maintenance load of the jax core, we try to see if other libraries like JAXopt might take it on instead. As an existing example of this, we're planning to shed jax.scipy.optimize.minimize because JAXopt already covers L-BFGS.