fax
This pull request (which depends on the HS suite) adds: 1. an Extragradient optimizer and tests (18/20 passing); 2. gradient descent/ascent (untested); 3. updates to the pip package; 4. a bug fix...
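As a rough illustration of why extragradient is worth adding alongside plain gradient descent/ascent, the sketch below runs both on the bilinear saddle-point problem min over x, max over y of x·y. Simultaneous descent/ascent spirals away from the saddle point at the origin, while extragradient (a look-ahead step, then an update from the original point using the look-ahead gradients) contracts toward it. This is a standalone sketch using `jax.grad`; the function names, objective, and step size are assumptions for illustration, not the optimizer API added by the pull request.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Bilinear saddle-point objective: minimize over x, maximize over y.
    return jnp.sum(x * y)


grad_x = jax.grad(f, argnums=0)
grad_y = jax.grad(f, argnums=1)


def gda_step(x, y, lr):
    # Simultaneous gradient descent on x and gradient ascent on y.
    return x - lr * grad_x(x, y), y + lr * grad_y(x, y)


def extragradient_step(x, y, lr):
    # Extrapolation ("look-ahead") step from the current point.
    x_half = x - lr * grad_x(x, y)
    y_half = y + lr * grad_y(x, y)
    # Update from the current point using gradients evaluated at the look-ahead point.
    return x - lr * grad_x(x_half, y_half), y + lr * grad_y(x_half, y_half)


def run(step, n_steps=100, lr=0.1):
    # Returns the distance from the saddle point (0, 0) after n_steps.
    step = jax.jit(step)
    x, y = jnp.array([1.0]), jnp.array([1.0])
    for _ in range(n_steps):
        x, y = step(x, y, lr)
    return float(jnp.sqrt(jnp.sum(x ** 2) + jnp.sum(y ** 2)))


print(run(gda_step))            # grows: roughly sqrt(1 + lr**2) per step
print(run(extragradient_step))  # shrinks: roughly sqrt(1 - lr**2 + lr**4) per step
```

On this problem each descent/ascent step scales the distance to the origin by sqrt(1 + lr²), whereas each extragradient step scales it by sqrt(1 - lr² + lr⁴), which is below 1 for 0 < lr < 1.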
With the (somewhat) recent changes to how `jax` handles custom VJPs, it is now possible to define derivatives using the function for which we are defining the derivative. Since the...
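The pattern of defining a derivative in terms of the function being differentiated can be sketched with `jax.custom_vjp`, following the fixed-point example in the JAX "Custom derivative rules" tutorial: the backward rule solves the adjoint equation u = x̄ + (∂f/∂x)ᵀu by calling `fixed_point` itself. This is an illustration of the JAX mechanism, not fax's implementation; `fixed_point`, `rev_iter` and `newton_sqrt` are example names, and the scalar stopping rule (`1e-6`) is an assumption that only suits scalar iterates.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
    # Plain fixed-point iteration x <- f(a, x) until successive iterates agree.
    def cond_fun(carry):
        x_prev, x = carry
        return jnp.abs(x_prev - x) > 1e-6

    def body_fun(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = jax.lax.while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
    return x_star


def fixed_point_fwd(f, a, x_init):
    x_star = fixed_point(f, a, x_init)
    # Only the solution is saved; intermediate iterates are not needed.
    return x_star, (a, x_star)


def rev_iter(f, packed, u):
    # One step of the adjoint fixed-point equation u = x_bar + (df/dx)^T u.
    a, x_star, x_star_bar = packed
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    return x_star_bar + vjp_x(u)[0]


def fixed_point_rev(f, res, x_star_bar):
    a, x_star = res
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    # The backward pass reuses fixed_point itself to solve for the adjoint u.
    u = fixed_point(partial(rev_iter, f), (a, x_star, x_star_bar), x_star_bar)
    (a_bar,) = vjp_a(u)
    # No gradient flows to the initial guess.
    return a_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)


def newton_sqrt(a):
    # Newton iteration whose fixed point is sqrt(a).
    update = lambda a, x: 0.5 * (x + a / x)
    return fixed_point(update, a, a)


print(newton_sqrt(2.0))            # approximately 1.4142135
print(jax.grad(newton_sqrt)(2.0))  # approximately 1 / (2 * sqrt(2))
```

Because the forward pass saves only `x_star`, memory use does not grow with the number of forward iterations.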
Our current default solver amounts to [Richardson iteration](https://en.wikipedia.org/wiki/Modified_Richardson_iteration) with the scaling factor "omega" equal to 1. Divergence is likely to occur if the spectral radius of df/dx in f(x,theta) = x...
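A minimal sketch of the damped (Richardson) update x ← x + ω(f(x) − x), where ω = 1 recovers plain fixed-point iteration x ← f(x). The test map f(x) = −2x + 3 has its fixed point at x = 1 and |df/dx| = 2 > 1, so plain iteration diverges while ω = 0.25 converges. The function name `richardson_fixed_point`, the tolerance, and the blow-up threshold are assumptions for illustration, not fax's solver.

```python
def richardson_fixed_point(f, x0, omega=1.0, tol=1e-10, max_iter=1000):
    """Solve x = f(x) with the damped update x <- x + omega * (f(x) - x).

    Returns (x, converged). omega = 1.0 is plain fixed-point iteration.
    """
    x = x0
    for _ in range(max_iter):
        x_new = x + omega * (f(x) - x)
        if abs(x_new - x) < tol:
            return x_new, True
        x = x_new
        if abs(x) > 1e12:  # treat blow-up as divergence
            return x, False
    return x, False


# Fixed point at x = 1, with df/dx = -2 (outside the unit circle).
f = lambda x: -2.0 * x + 3.0

print(richardson_fixed_point(f, 0.0))              # omega = 1: diverges
print(richardson_fixed_point(f, 0.0, omega=0.25))  # converges to 1.0
```

For a linear f with Jacobian J, the damped update has iteration matrix (1 − ω)I + ωJ. Here J = −2, giving the factor 1 − 3ω, so the iteration converges exactly when 0 < ω < 2/3.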
I'm not sure what the best design decision is, but it may be confusing for users to have to express a "zero problem" as a "fixed-point" one. Namely, if you want...
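One way to let users pose a zero problem g(x) = 0 directly is a thin wrapper that builds the fixed-point map f(x) = x − α·g(x). For any nonzero α its fixed points are exactly the zeros of g, but convergence still depends on α: near a root x*, plain iteration needs 0 < α·g′(x*) < 2. `zero_to_fixed_point` and `iterate` are hypothetical helpers for illustration, not part of fax.

```python
def zero_to_fixed_point(g, alpha):
    """Hypothetical helper: turn the zero problem g(x) = 0 into x = f(x)."""
    def f(x):
        # Fixed points of f are exactly the zeros of g (for alpha != 0).
        return x - alpha * g(x)
    return f


def iterate(f, x0, n_steps=50):
    # Plain fixed-point iteration x <- f(x) for a fixed number of steps.
    x = x0
    for _ in range(n_steps):
        x = f(x)
    return x


g = lambda x: x * x - 2.0  # zero at sqrt(2)
x = iterate(zero_to_fixed_point(g, alpha=0.25), 1.0)
print(x)  # approximately 1.41421356
```

Here g′(√2) = 2√2 ≈ 2.83, so α = 0.25 gives α·g′(x*) ≈ 0.71, inside the convergent range.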
Feel free to close this issue. It's more of a set of suggestions than it is an actual issue. With Clément's help and using fax as a base, I ended...
The `convergence_test` function takes `new_parameters` and `old_parameters` as arguments, e.g. https://github.com/gehring/fax/blob/15619388f9362d6365eabf074ad27b50cf08d8fd/fax/constrained/constrained_test.py#L60, but its docstring says it accepts the solution tuple: https://github.com/gehring/fax/blob/1c68fc6745c5831bb55fa1e914f3de5efac85e51/fax/loop.py#L58, which includes other values such as the number...
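For reference, a convergence test written against the (new, old) convention used in `constrained_test.py` might look like the sketch below, assuming the parameters are arbitrary JAX pytrees with matching structure. The tolerances are assumptions, and this does not settle which signature fax should document. The result is kept as a JAX boolean rather than a Python `bool` so it stays traceable inside `jax.lax.while_loop` or `jit`.

```python
import jax
import jax.numpy as jnp


def convergence_test(new_params, old_params, rtol=1e-6, atol=1e-6):
    # Per-leaf check: every entry changed by at most atol + rtol * |old|.
    close = jax.tree_util.tree_map(
        lambda n, o: jnp.all(jnp.abs(n - o) <= atol + rtol * jnp.abs(o)),
        new_params,
        old_params,
    )
    # Reduce the per-leaf booleans to a single scalar boolean.
    return jnp.all(jnp.stack(jax.tree_util.tree_leaves(close)))


params = {"w": jnp.ones(3), "b": jnp.zeros(())}
print(convergence_test(params, params))  # True
```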
The simplest way would be to use [`tjax.custom_vjp`](https://github.com/NeilGirdhar/tjax/blob/a38695f328ec891fa2f4b78be23ec0abde34bb30/tjax/shims.py#L20), but this is currently not possible due to `tjax`'s Python version requirements (#27).
Now that the two phase solver only returns the final solution, we should make sure there is a working mechanism for extracting from the forward and backward solvers any internal...
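One possible mechanism is for each phase to return a small named tuple of diagnostics, and for the two-phase wrapper to bundle both. The sketch below assumes each phase is a fixed-point solve: the forward phase solves x = 0.5x + θ (so x* = 2θ), and the backward phase solves the adjoint equation u = x̄ + (∂f/∂x)ᵀu with ∂f/∂x = 0.5. `SolverInfo`, `TwoPhaseInfo`, `solve` and `two_phase_grad` are hypothetical names, not fax's API.

```python
from typing import NamedTuple


class SolverInfo(NamedTuple):
    # Hypothetical per-phase diagnostics.
    value: float
    iterations: int
    converged: bool


class TwoPhaseInfo(NamedTuple):
    forward: SolverInfo
    backward: SolverInfo


def solve(f, x0, tol=1e-12, max_iter=1000):
    # Plain fixed-point iteration that reports its internal state.
    x = x0
    for i in range(1, max_iter + 1):
        x_new = f(x)
        if abs(x_new - x) < tol:
            return SolverInfo(x_new, i, True)
        x = x_new
    return SolverInfo(x, max_iter, False)


def two_phase_grad(theta, x_bar=1.0):
    # Forward phase: x* = f(x*, theta) with f(x, theta) = 0.5 * x + theta.
    fwd = solve(lambda x: 0.5 * x + theta, 0.0)
    # Backward phase: adjoint fixed point u = x_bar + 0.5 * u.
    bwd = solve(lambda u: x_bar + 0.5 * u, 0.0)
    # df/dtheta = 1, so the gradient with respect to theta is u itself.
    return bwd.value, TwoPhaseInfo(fwd, bwd)


grad, info = two_phase_grad(3.0)
print(grad, info.forward.iterations, info.backward.iterations)
```

Returning the diagnostics alongside the gradient keeps the default path unchanged, since callers that only need the solution can ignore the second element.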