fax
Support stochastic iterable functions (and other minor suggestions)
Feel free to close this issue. It's more of a set of suggestions than it is an actual issue. With Clément's help and using fax as a base, I ended up reimplementing the two-phase fixed point solver. I would prefer to use fax instead, but I can't do that unless fax has a few features:
- The ability to work with stochastic iterable functions. After a lot of discussion with Clément, that just means that you have two functions: one that is stochastic and is used to find the fixed point, and another that is the expected value of the first function, and is used in the `vjp`.
- The ability to work with arbitrary PyTree-like objects for both the state and the parameters. That's an easy fix: it means using `tree_map` and `tree_reduce` in a few places. (E.g., `dout = fp_vjp_fn(dout)[0] + dvalue` needs to map using `jnp.add`.)
- The object that defines the iteration should be pytree-like (I think that's fax's `params_func`). If it's not, then it needs to at least be hashable so that when it's passed as a static argument to a jitted function, it will always induce recompilation.
- The ability to calculate a trajectory with the same inputs that are used to calculate a fixed point.
Just as an example (you're welcome to use any code that's useful to you), I ended up implementing a version of your two-phase solver with the above features.
The user-facing code is here. The combinator is here. I borrowed your tests. And here I show how it's used to write pretty code in an object-oriented setting.
I did use classes, but after all, the `param_func` in fax is essentially an object. Unfortunately, a lot of the examples use closures rather than classes, which means that the closed-over values are practically inaccessible, which complicates debugging.
The links you posted seem to give me a 404 error.
The ability to work with stochastic iterable functions. After a lot of discussion with Clément, that just means that you have two functions: one that is stochastic and is used to find the fixed point, and another that is the expected value of the first function, and is used in the vjp.
This should be easy to support and I can't think of a reason not to.
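To make the two-function idea concrete, here is a minimal sketch in JAX. It is not fax's API: `expected_step`, `noisy_step`, and `make_stochastic_solver` are hypothetical names, the update rule is a deliberately simple linear contraction, and both loops run for a fixed number of steps rather than to a tolerance. The forward pass iterates the stochastic function; the backward pass differentiates only its expectation, using the usual implicit-differentiation recursion.

```python
import jax
import jax.numpy as jnp
from jax import lax


def expected_step(theta, x):
    # Hypothetical expected update: a linear contraction with fixed point 2 * theta.
    return 0.5 * x + theta


def noisy_step(theta, x, key):
    # Hypothetical stochastic update whose expectation is expected_step.
    noise = 0.01 * jax.random.normal(key, x.shape)
    return expected_step(theta, x) + noise


def make_stochastic_solver(key, num_forward=200, num_backward=50):
    def forward(theta, x0):
        # Iterate the stochastic function with a fresh key at every step.
        def body(x, k):
            return noisy_step(theta, x, k), None
        x, _ = lax.scan(body, x0, jax.random.split(key, num_forward))
        return x

    @jax.custom_vjp
    def solve(theta, x0):
        return forward(theta, x0)

    def solve_fwd(theta, x0):
        x_star = forward(theta, x0)
        return x_star, (theta, x_star)

    def solve_bwd(residuals, g):
        theta, x_star = residuals
        # Only the expected function is differentiated.
        _, vjp_x = jax.vjp(lambda x: expected_step(theta, x), x_star)
        _, vjp_theta = jax.vjp(lambda t: expected_step(t, x_star), theta)

        # Iterate w <- g + (dF/dx)^T w, which converges for a contraction.
        def body(w, _):
            return g + vjp_x(w)[0], None
        w, _ = lax.scan(body, g, None, length=num_backward)
        # The fixed point does not depend on x0, so its cotangent is zero.
        return vjp_theta(w)[0], jnp.zeros_like(x_star)

    solve.defvjp(solve_fwd, solve_bwd)
    return solve


theta = jnp.array([0.1, 0.2, 0.3])
x0 = jnp.zeros(3)
solve = make_stochastic_solver(jax.random.PRNGKey(0))
x_star = solve(theta, x0)
grad_theta = jax.grad(lambda t: jnp.sum(solve(t, x0)))(theta)
```

Here the key is closed over by the factory so that the `custom_vjp` function only takes differentiable arguments; that is a convenience of this sketch, not a recommendation for the library's interface.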
The ability to work with arbitrary PyTree-like objects for both the state, and the parameters. That's an easy fix: it means using tree_map and tree_reduce in a few places. (e.g., dout = fp_vjp_fn(dout)[0] + dvalue needs to map using jnp.add.)
Yes, we should have supported this from the start like we did for our conjugate gradient implementation.
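As a rough illustration of the pytree point, the sketch below replaces `+` with a leaf-wise `tree_map` and uses `tree_reduce` to collapse a pytree to a scalar. The helper names `tree_add` and `tree_sq_norm` are hypothetical, and the two dictionaries merely stand in for `fp_vjp_fn(dout)[0]` and `dvalue`.

```python
import jax
import jax.numpy as jnp


def tree_add(a, b):
    """Leaf-wise sum of two pytrees with identical structure."""
    return jax.tree_util.tree_map(jnp.add, a, b)


def tree_sq_norm(tree):
    """Sum of squared entries over every leaf, reduced to one scalar."""
    per_leaf = jax.tree_util.tree_map(lambda x: jnp.sum(x ** 2), tree)
    return jax.tree_util.tree_reduce(jnp.add, per_leaf, 0.0)


# Hypothetical stand-ins for `fp_vjp_fn(dout)[0]` and `dvalue`.
vjp_out = {"w": jnp.ones((2, 2)), "b": jnp.array([1.0, 2.0])}
dvalue = {"w": jnp.full((2, 2), 0.5), "b": jnp.array([0.0, -1.0])}

# Pytree-safe version of `dout = fp_vjp_fn(dout)[0] + dvalue`.
dout = tree_add(vjp_out, dvalue)
```

A reduction like `tree_sq_norm` is the kind of helper a convergence check would need once the state is no longer a single array.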
The object that defines the iteration should be pytree-like (I think that's fax's params_func). If it's not, then it needs to at least be hashable so that when it's passed as a static argument to a jitted function, it will always induce recompilation.
Not 100% sure what you are referring to. Overall, I would like to re-implement two-phase such that it doesn't require we define a solver explicitly, which would require all arguments play nicely with `jax.jit`, but I'm not sure how this is an issue in the current implementation.
The ability to calculate a trajectory with the same inputs that are used to calculate a fixed point.
Could you be a bit more specific about what you mean here? I don't think I understand.
Sorry about the links. I just reorganized the repo a few minutes ago.
Not 100% sure what you are referring to.
I'll explain what I was getting at. Currently, you have `param_func`, which is a callable that takes the parameters and produces another callable that takes the state (essentially) and produces the next state. One way to build `param_func` would be as a function with some bound arguments (e.g., additional stopping conditions). The problem with this is if you have a function like this:
from functools import partial
from jax import jit

@partial(jit, static_argnums=(0,))
def f(param_func, x, y, z):
    two_phase_solver(param_func, ...)
it will trigger recompilation whenever `param_func`'s bound arguments change. You can get around this by making `param_func` an object and registering it as a pytree, and then passing that object non-statically.
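A minimal sketch of that workaround follows. The class `DampedIteration` and its update rule are hypothetical and have nothing to do with fax's actual `param_func`; the point is only that an object registered as a pytree is passed as an ordinary traced argument, so changing its fields does not retrace the jitted function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class DampedIteration:
    """Hypothetical iteration object: x -> (1 - rate) * x + rate * cos(theta * x)."""

    def __init__(self, rate):
        self.rate = rate

    def __call__(self, theta, x):
        return (1 - self.rate) * x + self.rate * jnp.cos(theta * x)

    def tree_flatten(self):
        # `rate` is a dynamic leaf; there is no static auxiliary data.
        return (self.rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


trace_log = []


@jax.jit
def step(iteration, theta, x):
    trace_log.append("traced")  # Python side effect: runs only while tracing.
    return iteration(theta, x)


out1 = step(DampedIteration(jnp.asarray(0.5)), 1.0, 0.3)
# A different rate reuses the compiled function: no static argument changed.
out2 = step(DampedIteration(jnp.asarray(0.25)), 1.0, 0.3)
```

Anything that must stay static (for example, a Python callable or an integer step count used for shapes) would go in the auxiliary data returned by `tree_flatten` instead, and changing it would then trigger recompilation.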
The other advantage of making things objects over using closures is that you can access the members.
In my design, I chose to put all of the other parameters to `two_phase_solver` into the `param_func` object to keep things simple.
Could you be a bit more specific about what you mean here? I don't think I understand.
Not a big deal, but since you're providing a `while_loop` (for example) that finds the fixed point, it's often useful to also provide a `scan` (for example) that iterates towards the fixed point and outputs the whole trajectory. In my experiments, it has really helped to be able to see what's going on in order to debug my many mistakes.
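The pairing described above might look roughly like this. It is a sketch under assumptions: `iterate`, `fixed_point`, and `trajectory` are hypothetical names, and the update rule is an arbitrary contraction chosen for illustration. Both functions take the same `theta` and `x0`, so the trajectory can be inspected for exactly the problem the solver was given.

```python
import jax.numpy as jnp
from jax import lax


def iterate(theta, x):
    # Hypothetical contraction; its fixed point solves x = cos(theta * x).
    return jnp.cos(theta * x)


def fixed_point(theta, x0, max_iter=100, tol=1e-6):
    """Run until successive iterates differ by at most tol (or max_iter is hit)."""
    def cond(carry):
        i, x, x_prev = carry
        return (i < max_iter) & (jnp.abs(x - x_prev) > tol)

    def body(carry):
        i, x, _ = carry
        return i + 1, iterate(theta, x), x

    _, x, _ = lax.while_loop(cond, body, (0, iterate(theta, x0), x0))
    return x


def trajectory(theta, x0, num_steps=100):
    """Run a fixed number of steps and return every iterate."""
    def step(x, _):
        x_next = iterate(theta, x)
        return x_next, x_next

    _, xs = lax.scan(step, x0, None, length=num_steps)
    return xs


theta = 0.9
x0 = jnp.array(0.5, dtype=jnp.float32)
x_star = fixed_point(theta, x0)
xs = trajectory(theta, x0)
```

Plotting `xs` (or `jnp.abs(xs - x_star)`) is a quick way to see whether the iteration is converging, oscillating, or diverging.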
In my design, I chose to put all of the other parameters to two_phase_solver into the param_func object to keep things simple.
I think you're essentially describing an interface similar to what I had in mind for the re-design of this API. The main idea is to no longer require two calls, one for the solver and one to solve. It should make for some much cleaner and intuitive code.
a scan (for example) that iterates towards the fixed point and outputs the whole trajectory.
I see. That might be a bit outside the scope of what `fax` is trying to do, and it should be straightforward for most users to implement in a few lines. Is there a workflow or usage pattern that something a bit more integrated would help with?
I think you're essentially describing an interface similar to what I had in mind for the re-design of this API. The main idea is to no longer require two calls, one for the solver and one to solve. It should make for some much cleaner and intuitive code.
Cool!
I see. That might be a bit outside the scope of what fax is trying to do, and it should be straightforward for most users to implement in a few lines. Is there a workflow or usage pattern that something a bit more integrated would help with?
Fair enough. I understand what you're trying to do with FAX. I just thought that since the `param_func` essentially defines an iterative function, it would be useful to include trajectories since it's related to iterative functions.
Is this partly covered by #26 and #30? Can we close this issue?
It's been made easier in #26 and #30, but there are still some things to do, like properly supporting pytrees and adding support for stochastic vjps.