Passing a state between vector_field calls
Hi, first of all, thanks for this great library!
I have a use-case in which I want to solve some ODE, but every call to the vector_field requires running some inner iterative solver $s$ to get the time derivative of the state, i.e. $\frac{dy}{dt}=s(y)$. I would now like to give the inner solver the previous vector_field output as a warm start; is there a clever way to do this? Ideally there would be some option to return arbitrary auxiliary output from the vector_field that is passed directly over to the next call.
So far I have implemented a custom Euler integrator that supports this by assuming that the state is structured as a tuple of the actual state and the auxiliary output, and only integrating the actual state while passing the auxiliary output through directly. However, I was wondering if there is some simpler way to do it, as adapting my changes to e.g. the Runge-Kutta implementation would be quite tedious. I was also considering somehow passing the auxiliary output via the SolverState, but this again seems to require many changes.
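For reference, here's a minimal sketch of that idea (inner_solve is just a stand-in for my actual inner iterative solver):

# Minimal sketch: the "state" is a (y, aux) tuple; only y is integrated, while
# aux carries the previous inner-solver output forward as a warm start.
# inner_solve is a stand-in for the actual iterative solver.
def euler_step(t, dt, state):
    y, aux = state
    dy, new_aux = inner_solve(t, y, init=aux)  # warm-started inner iteration
    return y + dt * dy, new_aux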
Thanks in advance!
This is a request that's come up a few times now! See also #199, #462, and I think a few other issues besides.
Right now this is possible through an advanced API surface, by creating a custom solver. Some quick pseudocode:
import functools as ft

import diffrax
import equinox as eqx

def vector_field(t, y, args, *, paulus_state):
    ...

class PaulusSolver(diffrax.AbstractWrappedSolver):
    def step(self, terms, ..., solver_state, ...):
        paulus_state, sub_solver_state = solver_state
        assert isinstance(terms, diffrax.ODETerm)
        # Bind the extra state into the vector field for the duration of this step.
        terms = eqx.tree_at(
            lambda t: t.vector_field,
            terms,
            ft.partial(terms.vector_field, paulus_state=paulus_state),
        )
        ..., new_sub_solver_state, ... = self.solver.step(terms, ..., sub_solver_state, ...)
        # (With new_paulus_state recovered however the vector field exposes it.)
        solver_state = (new_paulus_state, new_sub_solver_state)
        return ..., solver_state, ...

solver = PaulusSolver(diffrax.Tsit5())
terms = diffrax.ODETerm(vector_field)
diffrax.diffeqsolve(terms, solver, ...)
Tagging @lockwo: this is a thing we should think about as part of the new stateful stuff. (I need to take a look at your PR!) Is this a thing we can easily support here, do you think?
This hijacking of the solver state was also how we had the stateful control initially. The PR does make the terms stateful (they now have an init and pass state around), but currently only for the .contr part (not the .vf), which will probably be invisible to most users. But once the stateful control PR is in, we've crossed the Rubicon and introduced a breaking change which makes terms stateful, so expanding this statefulness to include stateful vector fields (not just stateful controls) seems like a reasonable next step.
This also makes sense conceptually, since it's the vf that is actually stateful (rather than the solver that we are leeching the state off of). This change wouldn't be particularly difficult to add imo (just a lot of adding arguments to functions, making solvers and whatnot take an additional "vf_state" argument, just like I added "control_state"), although I haven't thought about whether there are any implications for the broader codebase, or concerns with autodiff, or anything like that.
This might be a more visible breaking change though, since passing a vf to a term would then probably have the signature def vf(t, y, args, state) instead of def vf(t, y, args) (or maybe we deprecate args and just pass around the state? idk, just an idea).
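Purely to illustrate the shape of that change (nothing here is an existing Diffrax API):

# Hypothetical stateful vector field: consumes the previous state and returns
# an updated one alongside the derivative. Illustrative only.
def vf(t, y, args, state):
    dy = -y                # the usual time derivative
    new_state = state + 1  # e.g. an updated warm start / cache / counter
    return dy, new_state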
On API design
Aha! I've just dug up https://github.com/patrick-kidger/diffrax/issues/490#issuecomment-2353442995 where I have a more fleshed-out version of my earlier pseudocode.
Your point about breaking changes and visibility is well made. I don't think I want to make this an 'obvious' API -- it's very much an advanced feature that makes it really easy to shoot yourself in the foot. More precisely, I mean that it shouldn't be an API that proliferates its way into many surfaces: rather, this kind of statefulness should be handled in just one place, whether that's a solver wrapper or something more designed for purpose.
I feel like so far we've done a good job of this in Diffrax! Solvers and step size controllers are different, event handling happens in one place, etc. Using any particular more advanced piece of functionality happens fairly straightforwardly.
So! I think we're left with an interesting API design question :) How to handle this kind of thing neatly.
A bad suggestion
Fundamentally, we just want a way to update arbitrary user state on a per-step basis. So here's a first possible (bad) suggestion (that I will improve in a second): add diffeqsolve(..., update_args=...), so that at every step we update args = update_args(args).
Vector fields which want to cache iterative subroutines can then do so through this (more precisely, they can do the iteration in update_args, and then provide that very-close-to-the-desired starting point to each vector field call -- i.e. the vf itself doesn't try to return an updated state, which isn't really even a well-defined notion given that different solvers may call vfs in all kinds of different ways).
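For concreteness, a rough sketch of how this might be used (update_args is hypothetical, and inner_solve/refine_warm_start are stand-ins for the user's iterative subroutine):

# Hypothetical sketch: update_args is not an existing diffeqsolve argument.
def update_args(args):
    warm_start, params = args
    # Refresh the cached starting point once per step, warm-started from the
    # previous step's result.
    return refine_warm_start(warm_start, params), params

def vector_field(t, y, args):
    warm_start, params = args
    # Every vf evaluation within the step starts from the cached point.
    return inner_solve(t, y, params, init=warm_start)

diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Tsit5(),
    t0, t1, dt0, y0,
    args=(initial_guess, params),
    update_args=update_args,  # hypothetical keyword
)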
This isn't perfect as I don't think it really generalises to the case of stateful controls, but I think it's leading us in the right direction: it's pleasingly simple and self-contained, uses existing abstractions, and works using some kind of per-step state that is then plumbed into the vector field and control -- not to have the vector field itself return state.
A good (?) suggestion
So, finessing the above: maybe instead of having stateful controls or stateful vector fields, what we really want are stateful terms? Terms are conveniently an abstraction that holds both vector fields and controls, and that is visible to the main diffeqsolve integration loop.
Then the main integration loop could be adjusted from something like this:
def body_fun(terms, ...):
    ... = solver.step(terms, ...)
into something like this:
def body_fun(terms, term_state, ...):
    step_terms, term_state = terms.step(term_state)
    ... = solver.step(step_terms, ...)
and inside of term.step is where we could arrange to do things like update our starting points for iterative solves, or to cache Brownian evaluations.
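As a rough illustration of what such a stateful term might look like (entirely hypothetical API -- init and step are not existing term methods, and initial_warm_start/refine are placeholders):

import functools as ft
import diffrax

# Hypothetical: a term-like object that owns per-solve state and refreshes it
# once per step. Sketch only; not a real AbstractTerm subclass.
class WarmStartTerm:
    def __init__(self, vector_field):
        self.vector_field = vector_field

    def init(self):
        # Initial per-solve term state, e.g. a starting point for an inner solve.
        return initial_warm_start

    def step(self, term_state):
        # Called once per step: update the cached state, then hand the solver a
        # plain ODETerm with that state baked into its vector field.
        new_state = refine(term_state)
        frozen_vf = ft.partial(self.vector_field, warm_start=new_state)
        return diffrax.ODETerm(frozen_vf), new_state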
WDYT? The above is a bit of a rough draft still so I might have gotten things wrong.
I think stateful terms are definitely a necessary condition, but I'm not sure they are sufficient (in the ideal case). Terms can manage their own state, and adding the vf state to the existing stateful terms I wrote would look basically the same as what you wrote: there's just a term state that manages everything (at least it looks the same with respect to terms).
However, I'm not sure I see this conveniently working inside the solver. step_terms is just a collection of the new terms, without the term state information (even though everything can be precomputed/set up in term.step). So if, inside my solver, I need to do something that requires updating the state, how should that be done? That is to say, if the abstract solver can call (and thus update) the state of the vector field (or of the terms), then it seems like the term_state needs to be passed around to the solver as well. E.g. maybe my vector field gets called multiple times with different parameters inside my solver, and each time I want to update the state. You could precompute everything in .step, but then the terms would have to know exactly what the solver is going to need in order to compute everything (concretely, if some solver calls my stateful function with updated states 3 times, then the terms would have to know that), which doesn't seem generalizable.
It seems like term statefulness manifests via managed term state, but also via accepting state parameters in vf/contr. I haven't really thought it out beyond the contr work, but those are my initial thoughts. (Also, the term.step/solver.step pattern seems a bit like the init/apply sort of style, but that could just be a bias of mine: if I want to call term.vf with a different state, I tend towards function inputs over changing the terms object to have that state embedded.)
Hi all, thanks for all your work!
In the suggestions in #462 and here, it seems that the auxiliary state is only updated at the level of the step: the vector field will use the same value of the auxiliary state for each vector field evaluation within a step.
However, in my research I have a dynamical system:
$\dot{y} = V[t, y; w(t)],$
where the evolution of the auxiliary variable $w(t)$ in time is defined by a deterministic dynamical map:
$w(t) = F[t; t_0, w(t_0)].$
Here, $w(t)$ may or may not be differentiable. Nevertheless, it is a function of actual physical time and should thus be re-evaluated for every call to the vector field, not just at every step. In one step of the RK4 solver, for instance, there would be four evaluations of the vector field with different values of $w$, and so $w$ would need to be updated within a step.
I know the statement was made in #462 that "[a]nything not part of the differentiable state can only be made sense of at the step-by-step level", but I don't see a reason why we couldn't have an auxiliary variable that changes at the level of the vector field itself? I would need exactly that, and the OP also said that they needed to do the update "at every call to the vector_field".
In the language above, that would be a "stateful vector field".
P.S. I now realize that the solution above does almost exactly what I want, since in principle I could use it to solve for $w(t)$ only on the interval $[t_0, t_1]$ at each step, interpolate, and feed the solution to the vector field. But for memory reasons it would be much simpler if I could think in terms of a stateful vector field.
In the case you have here then I don't think you need any state at all: just evaluate w(t) within your vector field. It's only a function of time t, which is one of the arguments available in the vector field.
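i.e., something like this (with w_fn and V standing in for your functions):

# Sketch: if w is a function of time only, no state is needed at all.
def vector_field(t, y, args):
    w = w_fn(t)  # evaluate the auxiliary variable directly from t
    return V(t, y, w)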
Sorry, I think I committed an abuse of notation and wasn't clear enough. Initially, I defined the function through the map:
$w(t) = F[t;t_0, w(t_0)]$.
To obtain the value at time $t$, I need the value of $w(t_0)$ at the last time $t_0$ at which the vector field was evaluated.
Of course, I could just evaluate $w(t)$ at predefined times, interpolate to obtain a function $w(t)$, and then do what you suggest. But this is very memory-inefficient in my case, and I'd rather have the value of $w(t)$ computed on the fly from the previous evaluation.
Suppose I want to do an RK4 step for $y$. Then the sequence within the step would go something like this:

- Start with $[y(t_0), w(t_0)]$ at $t_0$.
- Evaluate the vector field to get $k_1 = V[t_0, y(t_0); w(t_0)]$.
- Calculate $w\left(t_0+\frac{h}{2}\right)$ by propagating $w(t_0)$, i.e., $w\left(t_0+\frac{h}{2}\right) = F\left[t_0+\frac{h}{2}; t_0, w(t_0)\right]$.
- Evaluate the vector field to get $k_2 = V\left[t_0+\frac{h}{2}, y(t_0) + \frac{h}{2}k_1; w\left(t_0 + \frac{h}{2}\right)\right]$.
- etc...
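In code, one step of that scheme might look like this standalone sketch (V and F here are placeholders for my functions):

# Standalone sketch of one RK4 step in which the auxiliary variable w is
# propagated by F between vector field evaluations; V and F are placeholders.
def rk4_step_with_aux(V, F, t0, h, y0, w0):
    k1 = V(t0, y0, w0)
    w_half = F(t0 + h / 2, t0, w0)  # propagate w to the half step
    k2 = V(t0 + h / 2, y0 + h / 2 * k1, w_half)
    k3 = V(t0 + h / 2, y0 + h / 2 * k2, w_half)
    w1 = F(t0 + h, t0, w0)  # propagate w to the full step
    k4 = V(t0 + h, y0 + h * k3, w1)
    y1 = y0 + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y1, w1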
It's true that for differentiable $w(t_0)$, I can just include $w$ as a leaf of $y$ and write an equation for $\dot{w}(t)$, and that works well enough [Note]. But I have a case where I'm given a function $F$ that updates $w$ discretely within a step. In that case, the most natural thing for me was to think in terms of a stateful vector field, where I update $w(t)$ at each call of the vector field.
That said, as I write this, I see how it's not very natural to think of a discrete update within a step, since this is the differential part of the solver. It's probably better if I can reformulate the problem.
[Note] Even in the differentiable case, the time scales of $y(t)$ and $w(t)$ can be very different, so it's not clear that this is the best approach. I guess the more conceptually satisfying solution would be to build a custom solver step in which a multirate solver propagates $w(t)$ at its own rate within the step (see #661).
I actually have the same use-case, where I need to pass along a discrete state that gets updated at each vector_field call rather than each step of the solver.
It might be that it's generally not a good idea to design a system this way, but in my case the goal is to port an existing numpy/scipy implementation to JAX, where taking advantage of XLA is the main motivation rather than autodiff. I would like to replicate the original behavior as closely as numerically possible, ideally exactly, and that does require updating state within vector_field in this case.
Just like @lockwo said, my vector_field signature is (t, y, args, state) -> (y', state), where args are static args and state might change at every evaluation, depending on both t and its own previous value.
For now, to get things working, I just hacked together enough modifications to a bunch of function signatures in AbstractRungeKutta and ODETerm to support this additional passing around of state between subsequent calls to rk_stage.
I wonder if it would be possible for diffrax to provide more general support for stateful vector fields like this without requiring a bunch of breaking changes? It is a bit of a niche use case but it seems to pop up in a few different contexts.
@dsuedholt
> For now, to get things working, I just hacked together enough modifications to a bunch of function signatures in AbstractRungeKutta and ODETerm to support this additional passing around of state between subsequent calls to rk_stage.
I would be interested in a minimal working example of how you do this if you have one handy. Did it turn out to be complicated?
@BenjaminDAnjou sure. I wouldn't say complicated, just not particularly pretty or generalizable. It's very much a bandaid hack, but it does what I need it to do, i.e. replicate the behavior of an existing implementation. In my case, I just needed to make it work for a Dopri5 solver and an ODETerm without controls. For different terms or implicit solvers, you'd probably need to make different / more modifications.
My ODETerm.vf now looks like:
def vf(self, t: RealScalarLike, y: Y, args: Args, vf_state):
    out, vf_state = self.vector_field(t, y, args, vf_state)
    [...]
    return jtu.tree_map(_broadcast_and_upcast, out, y), vf_state
And then from there it was just about seeing which places in AbstractRungeKutta now throw errors about structure mismatch, and adjusting them to handle the state passing.
Solver initialization takes an additional vf_state argument and (by unpacking *f0) returns a 3-tuple (first_step, f0, vf_state) as the initial solver_state:
Solver init modifications:
def func(
    self,
    terms: AbstractTerm,
    t0: RealScalarLike,
    y0: Y,
    args: Args,
    vf_state,
) -> VF:
    return terms.vf(t0, y0, args, vf_state)

def init(
    self,
    terms: AbstractTerm,
    t0: RealScalarLike,
    t1: RealScalarLike,
    y0: Y,
    args: Args,
    vf_state,
) -> _SolverState:
    _, fsal = self._common(terms, t0, t1, y0, args)
    if fsal:
        first_step = jnp.array(True)
        f0 = sentinel
        [...]
        if f0 is sentinel:
            f0 = eqxi.eval_zero(self.func, terms, t0, y0, args, vf_state)
        return first_step, *f0
    else:
        return None
That makes vf_state part of solver_state. Then the signature changes need to be propagated throughout AbstractRungeKutta.step. Basically just making sure that whenever vf is called (again, just for my case of an ERK solver without a control), the actual value of the field and the state are separated out, and that the state is passed around the solver loop.
Solver step modifications:
def step(...):
    [...]

    def vf(t, y, vf_state, *, implicit_val):
        _assert_same_structure(y, y0)
        _vf = lambda term_i, t_i: term_i.vf(t_i, y, args, vf_state)
        out, vf_state = t_map(_vf, terms, t, implicit_val=implicit_val)
        if f0 is not _unused:
            _assert_same_structure(out, f0)
        return out, vf_state

    [...]
    if fsal:
        assert solver_state is not None
        first_step, f0, vf_state = solver_state
    [...]

    def rk_stage(val):
        stage_index, _, vf_state, _, dyn_jac_f, dyn_jac_k, fs, ks, result = val
        [...]
        if eval_fs:
            assert not vf_expensive
            assert implicit_fi is not _unused
            fi, vf_state = vf(ti, yi, vf_state, implicit_val=implicit_fi)
        [...]
        return (
            stage_index + 1,
            yi,
            vf_state,
            f1_for_fsal,
            dyn_jac_f,
            dyn_jac_k,
            fs,
            ks,
            result,
        )

    [...]
    init_val = (
        init_stage_index,
        y0,
        vf_state,
        dummy_f,
        dyn_jac_f,
        dyn_jac_k,
        fs,
        ks,
        RESULTS.successful,
    )
    [...]
    _, y1, vf_state, f1_for_fsal, _, _, fs, ks, result = final_val
    [...]
    if fsal:
        new_solver_state = False, f1_for_fsal, vf_state
    else:
        new_solver_state = None
    return y1, y_error, dense_info, new_solver_state, result
Then you can have
import diffrax
import jax
import jax.numpy as jnp

def vector_field(t, y, args, state):
    jax.debug.print("{t:.2f}: {state}", t=t, state=state)
    return -y, t + state ** t

term = diffrax.ODETerm(vector_field)
solver = diffrax.Dopri5()
t0 = 0.0
t1 = 1.0
y0 = jnp.array(1.0)
args = None
vf_state = jnp.array(0.0)

solver_state = solver.init(term, t0, t1, y0, args, vf_state)
solver.step(term, t0, t1, y0, args, solver_state, made_jump=False)
which prints
0.00: 0.0
0.20: 1.0
0.30: 1.2
0.80: 1.3562199684392582
0.89: 2.076039468639859
1.00: 2.8030884186162917
1.00: 3.8030884186162917
In retrospect, for a use case as specific as this, and using a standard Dopri5 solver, it probably would have been easier to make these modifications to jax.experimental.ode.odeint instead.