
Approaching multi trajectory adaptive stepping

Open lockwo opened this issue 1 year ago • 1 comments

In line with some of the weak solvers we are working on for a PR to diffrax, we are implementing a variety of adaptive methods. One of the schemes relies on estimating errors by looking at multiple trajectories (https://onlinelibrary.wiley.com/doi/abs/10.1002/pamm.200410005), i.e. you estimate some quantity from trajectories simulated simultaneously.

I wanted to think about how best to integrate this into diffrax philosophically, since this code works as a wrapper on top of it but isn't as trivial to implement in the framework itself. Since integrate.py conceptually works over a single trajectory, the usual way to get multiple trajectories is just to vmap, so I was thinking of playing around inside that and making an unvmap version of the computations we need (but defining custom unvmaps seemed very hacky). I was curious whether you had thought about this more and had opinions on adaptive schemes that rely on multiple trajectories?

lockwo avatar Aug 09 '24 20:08 lockwo

Hmm. You've got a couple of options I think. The first would be to bundle multiple trajectories together into one gigantic vector field (with each piece independent of the others). Diffrax just sees a single integration like normal. This would mean that a batch of solves would get fairly large (each batch element has its own 'inner batch' of trajectories). It would preserve batch independence, however.
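A minimal sketch of the bundling idea, in plain JAX. The dynamics `single_field`, the fixed-step `euler_solve` loop, and all shapes are hypothetical stand-ins chosen for illustration (they are not diffrax code); the point is only that the bundled field keeps the usual `(t, y, args)` signature, so a solver sees one state of shape `(n_traj, dim)`:

```python
import jax
import jax.numpy as jnp


def single_field(t, y, args):
    # Hypothetical per-trajectory dynamics: linear decay dy/dt = -a * y.
    return -args * y


# Bundle the trajectories: y has shape (n_traj, dim) and every row evolves
# independently, but the solver only ever sees one big state.
bundled_field = jax.vmap(single_field, in_axes=(None, 0, None))


def euler_solve(field, y0, t0, t1, n_steps, args):
    # Stand-in fixed-step solver (explicit Euler), used here only so the
    # sketch runs without any extra dependencies.
    dt = (t1 - t0) / n_steps

    def step(y, i):
        t = t0 + i * dt
        return y + dt * field(t, y, args), None

    y1, _ = jax.lax.scan(step, y0, jnp.arange(n_steps))
    return y1


# Inner batch: 3 trajectories of dimension 2, solved as a single system.
y0 = jnp.stack([jnp.ones(2), 2.0 * jnp.ones(2), 3.0 * jnp.ones(2)])
y1 = euler_solve(bundled_field, y0, 0.0, 1.0, 100, 0.5)

# Outer batch: vmap over whole bundles, so each outer batch element carries
# its own inner batch and stays independent of the other elements.
outer_y0 = jnp.stack([y0, 2.0 * y0])
batched = jax.vmap(
    lambda y: euler_solve(bundled_field, y, 0.0, 1.0, 100, 0.5)
)(outer_y0)
```

Because the whole inner batch lives in one state array, a cross-trajectory statistic (for example a mean over axis 0) could be computed inside a step without any cross-batch communication. The cost is that the outer batch grows to `(batch, n_traj, dim)`.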

The alternative would be to reach across the batch and explicitly create a cross-batch dependence. JAX provides tools to do this in the form of jax.lax.p{sum, ...}. Take a look at eqx.nn.BatchNorm for an example. Typically you name a particular vmap (via its axis_name argument) and then reduce over that name.
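A minimal sketch of the cross-batch approach, assuming a hypothetical per-trajectory quantity (`jnp.sum(y**2)`) and a hypothetical axis name `"traj"`. It only shows the mechanism (a named `jax.vmap` plus `jax.lax.pmean`), not an actual error estimator:

```python
import jax
import jax.numpy as jnp


def per_trajectory_estimate(y):
    # y is the state of ONE trajectory, shape (dim,).
    local = jnp.sum(y**2)  # hypothetical local quantity
    # Average the local quantity over every trajectory in the vmap axis
    # named "traj"; this is what creates the cross-batch dependence.
    return jax.lax.pmean(local, axis_name="traj")


ys = jnp.arange(6.0).reshape(3, 2)  # 3 trajectories, dim 2

# The axis_name here must match the one used inside the function.
est = jax.vmap(per_trajectory_estimate, axis_name="traj")(ys)
```

Every trajectory receives the same averaged value, so a quantity derived from it (such as a shared step size) would be consistent across the batch. Calling `per_trajectory_estimate` outside a vmap that binds `"traj"` raises an error, since the axis name is unbound.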

patrick-kidger avatar Aug 10 '24 16:08 patrick-kidger

In the end the alternative approach was what we went with (since it had precedent in equinox, and also came up in some discussions on jax issues with the core team as a recommended method). It turns out the RI weak error estimate scheme (which was only proved for 1D noise, though people use it for other problems anyway, e.g. https://docs.sciml.ai/DiffEqDocs/stable/solvers/sde_solve/#sde_solve) performed quite poorly compared to other weak solver approaches (but it'll be many months before these PRs are opened).

lockwo avatar Dec 14 '24 06:12 lockwo