Fixed-point vs zero form
I'm not sure what the best design decision is, but it may be confusing for users to have to recast a "zero problem" as a "fixed-point" one. Namely, to specify $F(x, \theta) = 0$, the user has to define a dummy function of the form $x = x + F(x, \theta)$ to make it compatible with the two-phase interface. We already have a lot of flags, so I'm not sure that adding one more is the right solution. If there is no programmatic solution, we should at least highlight it in the docs or with examples.
I'll have to think a bit more carefully about this, but is using the two-phase method on non-fixed-point problems advisable? At this point it might be better to use `jax.lax.custom_root` and a possibly more efficient gradient solver. WDYT?
I tend to put the fixed-point and root-finding perspectives in the same basket. I think the only difference is that our interface defaults to successive approximation if no forward solver is passed.
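For context, here's a minimal plain-Python sketch (not fax code, names are illustrative) of what successive approximation does as a forward solver: repeatedly apply the operator and hope the fixed point is attractive.

```python
import math

# Sketch of successive approximation, the default forward solver mentioned
# above: iterate x <- g(x); this only converges when the fixed point is
# attractive, i.e. |g'(x*)| < 1.
def successive_approx(g, x0, num_iters=100):
    x = x0
    for _ in range(num_iters):
        x = g(x)
    return x

x_star = successive_approx(math.cos, 1.0)
print(x_star)  # ≈ 0.7390851, the (attractive) fixed point of cos
```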
We can perhaps provide a decorator. Instead of having to write:

```python
def sqrt_zero(a):
    def _sqrt_zero(x):
        return x + (a - x**2)
    return _sqrt_zero
```

provide a decorator `zero_problem` that you use as:

```python
def sqrt_zero(a):
    @zero_problem
    def _sqrt_zero(x):
        return a - x**2
    return _sqrt_zero
```
Here's a more complete example:

```python
import jax.numpy as jnp

from fax import implicit

def root_finding(param_func):
    def _fp_param_func(params):
        fn_op = param_func(params)
        def _fp_operator(x):
            return x + fn_op(x)
        return _fp_operator
    return _fp_param_func

@root_finding
def square_root(a):
    def _square_root(x):
        return a - x**2
    return _square_root

print(implicit.two_phase_solve(square_root,
                               init_xs=jnp.array(1.),
                               params=jnp.array(4),
                               solvers=(lambda f, x0, a: jnp.sqrt(a),)))
# Returns 2.0
```
The more I think about it, the more I feel like transforming a root-finding problem into a fixed-point problem is impractical, for the simple reason that our two-phase fixed-point implementation assumes we are dealing with an attractive fixed point. If we provide wrappers to solve root problems, it won't be obvious to the user when to expect the default solvers to work and when they won't, without a good understanding of what is happening behind the scenes.
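To make the attractive-fixed-point caveat concrete, here is a plain-Python illustration (not fax code) of the sqrt example from above: rewriting $F(x) = a - x^2 = 0$ as the fixed point $x = x + F(x)$ gives a repelling fixed point, so naive successive approximation blows up.

```python
# For F(x) = a - x**2 with a = 4, the rewritten operator is
# g(x) = x + (a - x**2). Its fixed point x* = 2 is repelling:
# |g'(2)| = |1 - 2*2| = 3 > 1, so iteration moves away from the root.
def g(x, a=4.0):
    return x + (a - x ** 2)

x = 1.0
for _ in range(4):
    x = g(x)
print(x)  # -4688.0 after only 4 iterations: diverges away from x* = 2
```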
We could either:

1) implement a general function for handling implicit differentiation for root finding with no default solvers, like `jax.lax.custom_root`, and use the existing default solvers to re-implement `two_phase_solve` using that API; or
2) maintain both a fixed-point-specific method and a separate custom root method, which would duplicate some code but would maintain a 1-to-1 correspondence between `two_phase_solve` and Christianson's method.

I think I am leaning towards 1), as long as we can keep the `two_phase_solve` API unchanged (or nearly unchanged) and we can keep supporting higher-order derivatives of fixed points. WDYT?
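For a sense of what option 1) could look like, here is a rough sketch (all names and signatures are hypothetical, this is neither fax's nor `jax.lax.custom_root`'s actual API) of a custom-root wrapper with no default solver, built on `jax.custom_vjp` and the implicit function theorem, restricted to the scalar case for brevity:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical sketch: differentiate a root x*(theta) of f(x, theta) = 0
# through a user-supplied (possibly non-differentiable) forward solver.
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def root_solve(f, solver, params, x0):
    # phase 1: forward solve with whatever solver the user passes in
    return solver(f, params, x0)

def _root_solve_fwd(f, solver, params, x0):
    x_star = solver(f, params, x0)
    return x_star, (params, x_star)

def _root_solve_bwd(f, solver, res, g):
    params, x_star = res
    # phase 2: implicit function theorem (scalar case):
    # dx*/dtheta = -(df/dx)^{-1} df/dtheta
    dfdx = jax.grad(f, argnums=0)(x_star, params)
    dfdp = jax.grad(f, argnums=1)(x_star, params)
    return (-g * dfdp / dfdx, jnp.zeros_like(x_star))

root_solve.defvjp(_root_solve_fwd, _root_solve_bwd)

def f(x, a):
    return a - x ** 2  # root at x = sqrt(a)

def newton(f, a, x0, num_iters=20):
    # simple Newton iteration standing in for a user-supplied solver
    x = x0
    for _ in range(num_iters):
        x = x - f(x, a) / jax.grad(f, argnums=0)(x, a)
    return x

grad_fn = jax.grad(lambda a: root_solve(f, newton, a, jnp.array(1.0)))
print(grad_fn(jnp.array(4.0)))  # ≈ 0.25 = d sqrt(a)/da at a = 4
```

The point of the sketch is that nothing here assumes an attractive fixed point; the forward and backward solvers are fully decoupled, which is what would let the existing default solvers re-implement `two_phase_solve` on top of such an API.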
Also, while we're on the topic, `jax.lax.custom_root` expects JAX-transformable solvers (as far as I can tell), which makes it ill-suited for non-jit'able cases where the solver might be some external program. Until JAX natively supports XLA's `custom_call` op, we'll want to use our own implementation.
You make a good point regarding the use of successive approximation as a default "forward" solver, and how this is ill-suited and confusing for the user under the root-finding perspective. I made that mistake myself! If I understand correctly, your proposal is to ask the user to specify their own forward solver as a non-optional argument to `two_phase_solve`. I think that makes sense. The drawback, of course, is that it puts more burden on the user when invoking `two_phase_solve`, but it also makes things safer. If necessary, we can always provide a small abstraction/wrapper around `two_phase_solve`, specialized for each case.
I don't think the code duplication route is the way to go here (and performance-wise, I'm confident the XLA magic will close any potential gap between the fixed-point backend and a pure root one).
It would be good to refactor out `default_solver` in that case: https://github.com/gehring/fax/blob/be058d17fb7d650ba1ebc093179a0f735a738054/fax/implicit/twophase.py#L13-L48. This piece of code can also be used in other contexts where we need to solve linear systems in a matrix-free fashion.
*What I meant was more the part about the solver derived from the Neumann-series perspective: https://github.com/gehring/fax/blob/be058d17fb7d650ba1ebc093179a0f735a738054/fax/implicit/twophase.py#L112-L120
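For readers following along, here is a minimal sketch (not fax's actual implementation, names are illustrative) of what a Neumann-series solver for the backward phase looks like: solving $(I - A)u = v$ matrix-free via $u = v + Av + A^2v + \dots$, using only products with $A$, shown in the scalar case.

```python
# Neumann-series solve of (I - A) u = v using only matvecs with A.
# Converges when the spectral radius of A is < 1, which is exactly the
# attractive-fixed-point condition discussed above.
def neumann_solve(matvec, v, num_iters=60):
    u = v
    for _ in range(num_iters):
        u = v + matvec(u)  # u <- v + A u, i.e. the truncated series
    return u

# scalar example with A = 0.5: (1 - 0.5)^{-1} * 1 = 2
u = neumann_solve(lambda x: 0.5 * x, 1.0)
print(u)  # ≈ 2.0
```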