
Forcing solver to stay in given region

Open • grfrederic opened this issue 2 years ago • 2 comments

I'm solving a system of ODEs that simulates concentrations of some substances. I know that the concentrations have to be real numbers in the range [0, 1]. Normally, after each step, I would simply clamp the values to be in that region (since values outside of that range can mess with the simulation). Is there a way to achieve this nicely with diffrax?

I'm trying out a pretty wide range of parameters (since I'm using the simulations for MCMC), so tightening the tolerances (and hence shrinking the step sizes) just to handle the outliers seems wasteful.

grfrederic, Dec 04 '22 11:12

There are a couple of ways you could do this. The first, as you say, is to just clamp the values to the desired region. You can do this by wrapping the solver:

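from typing import Callable

import diffrax as dfx
import jax.numpy as jnp

class ClampSolver(dfx.AbstractSolver):
    solver: dfx.AbstractSolver
    clamp: Callable

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        # Take the step with the wrapped solver...
        y1, y_error, dense_info, solver_state, result = self.solver.step(
            terms, t0, t1, y0, args, solver_state, made_jump
        )
        # ...then clamp the new state before it is used for the next step.
        # (`dense_info` is passed through unchanged, so dense output is not clamped.)
        y1 = self.clamp(y1)
        return y1, y_error, dense_info, solver_state, result

    # also forward any other methods/attributes (e.g. `init`, `func`, `order`,
    # `term_structure`, `interpolation_cls`) to `self.solver` as needed

clamp = lambda y: jnp.clip(y, 0, 1)
solver = ClampSolver(dfx.Tsit5(), clamp)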
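To see what clamping does without any diffrax machinery, here is a dependency-free sketch in plain Python. It uses a hypothetical fixed-step forward-Euler loop (`euler_clamped`, not part of diffrax) on the decay dy/dt = -30y with step 0.1, where each unclamped step multiplies y by 1 - 30·0.1 = -2 and so overshoots past zero. The `clamp` here is a plain-Python stand-in for the `jnp.clip` version above.

```python
def euler_clamped(f, t0, y0, h, n_steps, clamp):
    """Fixed-step forward Euler that applies `clamp` after every step."""
    t, y = t0, y0
    ts, ys = [t], [y]
    for _ in range(n_steps):
        y = clamp(y + h * f(t, y))  # step, then clamp the new state
        t = t + h
        ts.append(t)
        ys.append(y)
    return ts, ys


def clamp(y):
    # Plain-Python stand-in for jnp.clip(y, 0, 1).
    return min(max(y, 0.0), 1.0)


def decay(t, y):
    return -30.0 * y


_, ys_raw = euler_clamped(decay, 0.0, 1.0, 0.1, 5, clamp=lambda y: y)  # no clamping
_, ys = euler_clamped(decay, 0.0, 1.0, 0.1, 5, clamp=clamp)
print(ys_raw)  # [1.0, -2.0, 4.0, -8.0, 16.0, -32.0]
print(ys)      # [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Clamping keeps the state in [0, 1], but here it does so by snapping to 0 after the very first step, which is the kind of accuracy loss the second approach avoids.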

Another way, which gives more accurate solutions, is to have the step size controller reject any step that lands out of bounds. We do this by setting that step's error estimate to infinity, so instead of having its result clamped, the step is retried with a smaller step size. This assumes you're using an adaptive step size controller such as PIDController.

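from typing import Callable

import diffrax as dfx
import jax
import jax.numpy as jnp

class InBoundsSolver(dfx.AbstractSolver):
    solver: dfx.AbstractSolver
    out_of_bounds: Callable

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        y1, y_error, dense_info, solver_state, result = self.solver.step(
            terms, t0, t1, y0, args, solver_state, made_jump
        )
        # `oob` is a scalar boolean: True if any component of y1 left the region.
        oob = self.out_of_bounds(y1)
        # An infinite error estimate makes the controller reject the step and
        # retry with a smaller one. (Requires a solver that returns a y_error.)
        inf_if_oob = lambda e: jnp.where(oob, jnp.inf, e)
        y_error = jax.tree_util.tree_map(inf_if_oob, y_error)
        return y1, y_error, dense_info, solver_state, result

    # also forward any other methods/attributes (e.g. `init`, `func`, `order`,
    # `term_structure`, `interpolation_cls`) to `self.solver` as needed

# Reduce to a single flag so it broadcasts against every leaf of y_error.
out_of_bounds = lambda y: jnp.any((y < 0) | (y > 1))
solver = InBoundsSolver(dfx.Tsit5(), out_of_bounds)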
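The same reject-and-retry logic, stripped of diffrax, looks roughly like the sketch below. It is a toy adaptive integrator written for illustration only (Heun's method with an embedded Euler error estimate and a simple clipped step-size update), not what PIDController actually implements, and the names `heun_step` and `solve_in_bounds` are hypothetical. As in InBoundsSolver, an out-of-bounds proposal gets an infinite error, so the scaled error exceeds 1, the step is rejected, and the step size shrinks by the minimum factor (0.2 here).

```python
import math


def heun_step(f, t, y, h):
    """One Heun (order 2) step plus an error estimate from embedded Euler (order 1)."""
    k1 = f(t, y)
    y_euler = y + h * k1
    k2 = f(t + h, y_euler)
    y_heun = y + 0.5 * h * (k1 + k2)
    return y_heun, y_heun - y_euler


def solve_in_bounds(f, t0, t1, y0, h0, rtol=1e-2, atol=1e-4, lo=0.0, hi=1.0, max_steps=10_000):
    """Toy adaptive solver that rejects any step landing outside [lo, hi]."""
    t, y, h = t0, y0, h0
    ts, ys, rejected = [t0], [y0], 0
    for _ in range(max_steps):
        if t >= t1:
            break
        last = t + h >= t1
        if last:
            h = t1 - t  # do not step past the end of the interval
        y1, y_err = heun_step(f, t, y, h)
        if not lo <= y1 <= hi:
            y_err = math.inf  # same trick as InBoundsSolver
        scaled = abs(y_err) / (atol + rtol * max(abs(y), abs(y1)))
        if scaled <= 1.0:  # accept
            t = t1 if last else t + h
            y = y1
            ts.append(t)
            ys.append(y)
        else:  # reject: keep t and y, retry with the updated (smaller) h
            rejected += 1
        # The error estimate scales like h**2, hence the exponent -1/2;
        # the factor is clipped to [0.2, 5]. inf ** -0.5 == 0.0, so it clips to 0.2.
        factor = 0.9 * scaled ** -0.5 if scaled > 0 else 5.0
        h *= min(5.0, max(0.2, factor))
    return ts, ys, rejected


# A substance consumed at rate 50: dy/dt = -50 y, starting fully concentrated.
ts, ys, rejected = solve_in_bounds(lambda t, y: -50.0 * y, 0.0, 1.0, 1.0, h0=0.1)
```

Every accepted value stays in [0, 1] by construction, and the very first proposal (a Heun step of size 0.1, which lands at about 8.5) is rejected and retried rather than clamped.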

Eventually I intend to provide built-in support for this kind of thing, so I'm also going to mark this as a feature request.

patrick-kidger, Dec 07 '22 20:12

Thank you! I really like the second solution.

grfrederic, Dec 14 '22 09:12