Forcing solver to stay in given region
I'm solving a system of ODEs that simulates concentrations of some substances. I know that the concentrations have to be real numbers in the range [0, 1]. Normally, after each step, I would simply clamp the values to be in that region (since values outside of that range can mess with the simulation). Is there a way to achieve this nicely with diffrax?
I'm trying out a pretty wide range of parameters (since I'm using the simulations for MCMC), so tightening the tolerances and shrinking the step sizes just to handle the outliers seems wasteful.
There are a couple of ways you could do this. The first, as you say, is to just clamp the values to the desired region. You can do this by wrapping the solver:
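from typing import Callable

import diffrax as dfx
import jax.numpy as jnp

class ClampSolver(dfx.AbstractSolver):
    solver: dfx.AbstractSolver
    clamp: Callable

    def step(self, ...):
        # take a normal step with the wrapped solver, then clamp the result
        y1, ... = self.solver.step(...)
        y1 = self.clamp(y1)
        return y1, ...

    # also forward any other methods to `self.solver` as needed

# bounds passed positionally, which works across JAX versions
clamp = lambda y: jnp.clip(y, 0, 1)
solver = ClampSolver(dfx.Tsit5(), clamp)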
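To make the clamping idea concrete without relying on any diffrax internals, here is a minimal pure-Python sketch. It assumes a hand-rolled fixed-step explicit Euler loop, and the names `clamp`, `euler_step`, `solve_clamped` and `decay` are made up for this illustration. The step size is deliberately far too large, so the raw Euler update leaves [0, 1] and the clamp has to pull it back.

```python
def clamp(y, lo=0.0, hi=1.0):
    # Project every component back into [lo, hi].
    return [min(max(v, lo), hi) for v in y]

def euler_step(f, t, y, h):
    # One explicit Euler step: y + h * f(t, y).
    dy = f(t, y)
    return [yi + h * di for yi, di in zip(y, dy)]

def solve_clamped(f, y0, t0, t1, h):
    # Fixed-step solve that clamps the state after every step.
    t, y = t0, list(y0)
    ys = [y]
    while t < t1 - 1e-12:
        step = min(h, t1 - t)
        y = clamp(euler_step(f, t, y, step))
        t += step
        ys.append(y)
    return ys

def decay(t, y):
    # Toy system: substance a converts into b at rate 5.
    a, b = y
    return [-5.0 * a, 5.0 * a]

# h = 0.5 overshoots badly: the first raw Euler step gives a = -1.5, b = 2.5.
ys = solve_clamped(decay, [1.0, 0.0], 0.0, 2.0, 0.5)
```

Every stored state stays inside [0, 1], but the clamped values are not an accurate solution of the ODE: the first step snaps straight to a = 0, b = 1. That is the trade-off of clamping.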
Another way (that will give more accurate solutions) is to have the step size controller reject any steps that are out of bounds (which we do by setting the error estimate to infinity). This assumes you're using an adaptive step size controller such as PIDController.
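from typing import Callable

import diffrax as dfx
import jax
import jax.numpy as jnp

class InBoundsSolver(dfx.AbstractSolver):
    solver: dfx.AbstractSolver
    out_of_bounds: Callable

    def step(self, ...):
        y1, y_error, ... = self.solver.step(...)
        # wherever y1 is out of bounds, report an infinite error so the
        # step size controller rejects the step and retries with a smaller one
        oob = self.out_of_bounds(y1)
        keep = lambda y: jnp.where(oob, jnp.inf, y)
        y_error = jax.tree_util.tree_map(keep, y_error)
        return y1, y_error, ...

    # also forward any other methods to `self.solver` as needed

out_of_bounds = lambda y: (y < 0) | (y > 1)
solver = InBoundsSolver(dfx.Tsit5(), out_of_bounds)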
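The reject-on-out-of-bounds idea can also be sketched in plain Python, independent of diffrax. This is only an illustration under stated assumptions: it uses an embedded Euler/Heun pair and a very simple step size rule (not diffrax's PIDController), and `heun_step`, `out_of_bounds`, `solve_in_bounds` and `decay` are hypothetical names. The key line is the one that overwrites the error estimate with infinity when the proposed state leaves [0, 1].

```python
import math

def heun_step(f, t, y, h):
    # Embedded pair: Euler (1st order) and Heun (2nd order).
    k1 = f(t, y)
    y_low = [yi + h * ki for yi, ki in zip(y, k1)]
    k2 = f(t + h, y_low)
    y_high = [yi + 0.5 * h * (a + b) for yi, a, b in zip(y, k1, k2)]
    # Error estimate: largest difference between the two solutions.
    err = max(abs(p - q) for p, q in zip(y_high, y_low))
    return y_high, err

def out_of_bounds(y, lo=0.0, hi=1.0):
    return any(v < lo or v > hi for v in y)

def solve_in_bounds(f, y0, t0, t1, h, tol=1e-4):
    t, y = t0, list(y0)
    ys, rejected_oob = [y], 0
    while t1 - t > 1e-12:
        h = min(h, t1 - t)
        y_new, err = heun_step(f, t, y, h)
        if out_of_bounds(y_new):
            # Pretend the error is infinite so the step is rejected.
            err = math.inf
            rejected_oob += 1
        if err <= tol:
            # Accept the step.
            t, y = t + h, y_new
            ys.append(y)
        # Grow or shrink h based on the error; an infinite error gives the
        # maximum shrink factor of 0.2.
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 5.0
        h *= min(5.0, max(0.2, factor))
    return ys, rejected_oob

def decay(t, y):
    # Toy system: substance a converts into b at rate 5.
    a, b = y
    return [-5.0 * a, 5.0 * a]

# The initial h = 0.5 is too large, so the first proposed step is out of bounds.
ys, rejected_oob = solve_in_bounds(decay, [1.0, 0.0], 0.0, 2.0, h=0.5)
```

Unlike clamping, no accepted state is ever modified: an out-of-bounds proposal just costs a retry with a smaller step, so the accepted states are still genuine solver output.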
Eventually I intend to provide built-in support for this kind of thing, so I'm also going to mark this as a feature request.
Thank you! I really like the second solution.