How to enforce non-negativity constraints?
I'm solving an ODE system for viral dynamics using Kvaerno4 with a PID controller, and the state variables must stay non-negative.
I searched the docs and issues but didn't find anything about handling non-negativity constraints. Currently I'm forcing the states to be non-negative by clamping them to zero inside the vector field:
import jax.numpy as jnp
def ode_system(t, state, params):
    state = jnp.maximum(state, 0.0)  # <-- this is my current workaround
    # ..... compute derivatives ....
    return derivatives
But this is generally a bad approach: it interferes with error control and introduces discontinuities into the right-hand side, which is why I'm asking for guidance.
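To make the discontinuity concrete, here is a minimal sketch using a hypothetical one-compartment clearance model, dV/dt = -C V (the model and the constant `C` are illustrative assumptions, not the system from the question). With the state clamped, the Jacobian that an implicit solver such as Kvaerno4 relies on in its Newton iterations jumps from -C to 0 at V = 0, and a slightly negative V is frozen in place instead of being pushed back toward zero.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: first-order clearance, dV/dt = -C * V.
C = 3.0

def rhs_clamped_state(v):
    # The workaround from the question: clamp the *state* before evaluating the RHS.
    return -C * jnp.maximum(v, 0.0)

def rhs_unclamped(v):
    # The same model without clamping.
    return -C * v

# Jacobians an implicit solver would see on either side of V = 0.
jac_clamped = jax.jacfwd(rhs_clamped_state)
jac_above = jac_clamped(jnp.array([0.1]))   # -C: normal decay
jac_below = jac_clamped(jnp.array([-0.1]))  # 0: the Jacobian jumps at V = 0

# Slightly below zero, the clamped RHS returns 0, so the negative value is
# frozen; the unclamped RHS returns +0.3 and pulls V back toward 0.
f_clamped_below = rhs_clamped_state(jnp.array([-0.1]))
f_unclamped_below = rhs_unclamped(jnp.array([-0.1]))
```

The solver's own copy of y is never clamped by this workaround; only the value fed into the RHS is, which is why the solver can drift below zero without noticing.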
Looking at how MATLAB's ode15s handles this: it supports non-negativity through odeset with the NonNegative option. The solver does two things. First, it wraps the ODE function to modify the derivatives; second, it incorporates constraint violations into the error estimate used for step acceptance/rejection. Inspecting odenonnegative.m reveals the derivative modification:
function yp = local_odeFcn_nonnegative(idxNonNegative, ode, t, y, varargin)
  yp = feval(ode, t, y, varargin{:});
  ndx = idxNonNegative(find(y(idxNonNegative) <= 0));
  yp(ndx) = max(yp(ndx), 0);  % <-- here
end
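A rough JAX translation of that MATLAB wrapper is sketched below. One difference is forced by JAX tracing: `find` returns a data-dependent number of indices, which JAX cannot trace, so the sketch instead applies `jnp.where` over all constrained components with a fixed shape. The names `make_nonnegative_rhs` and the toy `f` are hypothetical.

```python
import jax.numpy as jnp

def make_nonnegative_rhs(f, idx_nonnegative):
    """Wrap f(t, y, args) in the spirit of MATLAB's local_odeFcn_nonnegative.

    For each constrained component at or below zero, a negative derivative is
    replaced by zero, so the solver is never pushed further below zero.
    Unconstrained components are left untouched.
    """
    # JAX needs an array (not a Python list) for integer-array indexing.
    idx = jnp.asarray(idx_nonnegative)

    def wrapped(t, y, args):
        yp = f(t, y, args)
        at_or_below_zero = y[idx] <= 0
        # Fixed-shape replacement for MATLAB's find(...): select per component.
        clamped = jnp.where(at_or_below_zero, jnp.maximum(yp[idx], 0.0), yp[idx])
        return yp.at[idx].set(clamped)

    return wrapped

# Hypothetical two-component RHS: y[0] decays, y[1] is driven down at a constant rate.
def f(t, y, args):
    return jnp.array([-2.0 * y[0], -1.0])

rhs = make_nonnegative_rhs(f, [1])  # only y[1] is constrained
at_zero = rhs(0.0, jnp.array([1.0, 0.0]), None)   # y[1] == 0: its derivative is clamped
positive = rhs(0.0, jnp.array([1.0, 0.5]), None)  # y[1] > 0: nothing changes
```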
Then during the main integration loop, after computing a candidate solution step, the solver checks if any constrained variables went negative and computes an additional error term that can trigger step rejection.
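The step-rejection half of that idea can be sketched in plain JAX as follows. This is a simplified illustration, not ode15s's exact formula: it assumes the common mixed-tolerance error norm (scale `atol + rtol * max(|y0|, |y1|)`, accept when the scaled RMS error is at most 1), measures how far each constrained component of the candidate `y1` fell below zero on a similar scale, and rejects the step if either test fails. All function names here are hypothetical, and wiring this into Diffrax itself would need a custom solver or controller, which the sketch does not attempt.

```python
import jax.numpy as jnp

def scaled_error_norm(y_err, y0, y1, rtol, atol):
    # Usual mixed absolute/relative RMS error norm; a step passes if <= 1.
    scale = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
    return jnp.sqrt(jnp.mean((y_err / scale) ** 2))

def nonnegativity_error(y1, idx_nonnegative, rtol, atol):
    # How far constrained components of the candidate step went below zero,
    # measured on a tolerance-based scale (0 if none went negative).
    idx = jnp.asarray(idx_nonnegative)
    violation = jnp.maximum(-y1[idx], 0.0)
    scale = atol + rtol * jnp.abs(y1[idx])
    return jnp.max(violation / scale)

def accept_step(y_err, y0, y1, idx_nonnegative, rtol=1e-6, atol=1e-8):
    err = scaled_error_norm(y_err, y0, y1, rtol, atol)
    err_nn = nonnegativity_error(y1, idx_nonnegative, rtol, atol)
    # Taking the max means the step must pass both the usual error test
    # and the constraint test; a rejected step would then be retried smaller.
    total = jnp.maximum(err, err_nn)
    return total <= 1.0, total

# Hypothetical candidate steps from y0, with a tiny embedded error estimate.
y0 = jnp.array([1.0, 0.01])
y_err = jnp.array([1e-9, 1e-9])
ok, _ = accept_step(y_err, y0, jnp.array([0.9, 0.005]), [1])        # stays >= 0
bad, total_bad = accept_step(y_err, y0, jnp.array([0.9, -1e-3]), [1])  # y[1] < 0
```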
Is there a recommended pattern for this in Diffrax? I'm still fairly new to JAX and Diffrax, so I don't understand the internals well enough to implement something similar myself. I'd appreciate any pointers, or a note if there's an existing approach I'm missing.
Yup, this is totally possible. Take a look at the InBoundsSolver in https://github.com/patrick-kidger/diffrax/issues/200.
(Which we should probably add to Diffrax's API directly!)
Thanks, that InBoundsSolver approach is exactly what I needed to see. I ended up implementing something similar: a wrapper around the ODE function that stops a variable's derivative from going negative once that variable has reached zero:
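import jax.numpy as jnp
class NonNegativeWrapper:
    """Wraps f(t, y, args) so that constrained components at or below zero cannot decrease further."""
    def __init__(self, ode_func, nonneg_indices=None):
        self.ode_func = ode_func
        # JAX does not accept a Python list as an index, so convert it to an array once.
        self.nonneg_indices = (
            None if nonneg_indices is None else jnp.asarray(nonneg_indices)
        )
    def __call__(self, t, y, args):
        dydt = self.ode_func(t, y, args)
        if self.nonneg_indices is None:
            # Constrain every component: where y <= 0, forbid negative derivatives.
            mask = y <= 0
            dydt = jnp.where(mask, jnp.maximum(dydt, 0), dydt)
        else:
            # Constrain only the selected components; the rest pass through unchanged.
            idx = self.nonneg_indices
            mask = y[idx] <= 0
            dydt = dydt.at[idx].set(
                jnp.where(mask, jnp.maximum(dydt[idx], 0), dydt[idx])
            )
        return dydt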
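If the set of constrained components is fixed, the index/`None` branching in the wrapper above could be replaced by a single boolean mask. The sketch below (the name `nonnegative_rhs_masked` and the toy RHS are hypothetical) uses one `jnp.where` over the whole array, avoids the gather/scatter, and expresses "constrain everything" as an all-`True` mask. The returned closure has the `(t, y, args)` signature that `diffrax.ODETerm` expects.

```python
import jax
import jax.numpy as jnp

def nonnegative_rhs_masked(f, constrained_mask):
    """Wrap f(t, y, args) using a boolean mask instead of an index list.

    constrained_mask has the same shape as y; True marks components that must
    stay non-negative. Where a constrained component is at or below zero, a
    negative derivative is replaced by zero; everything else passes through.
    """
    constrained_mask = jnp.asarray(constrained_mask, dtype=bool)

    def wrapped(t, y, args):
        dydt = f(t, y, args)
        active = constrained_mask & (y <= 0)
        return jnp.where(active, jnp.maximum(dydt, 0.0), dydt)

    return wrapped

# Hypothetical RHS that pushes both components downward at a constant rate.
def f(t, y, args):
    return jnp.full_like(y, -1.0)

rhs = jax.jit(nonnegative_rhs_masked(f, [False, True]))
out = rhs(0.0, jnp.array([0.0, 0.0]), None)  # only y[1] is protected
```

Both components sit at zero here, but only the constrained one has its derivative clamped; the unconstrained one keeps its negative derivative.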
Usage is straightforward:
import diffrax
import optimistix
wrapped_ode = NonNegativeWrapper(
    ode_system,  # pass the function directly; no lambda needed
    nonneg_indices=None,  # or specify indices like [0, 1, 3]
)
term = diffrax.ODETerm(wrapped_ode)
solution = diffrax.diffeqsolve(
    terms=term,
    solver=diffrax.Kvaerno4(root_finder=optimistix.Newton(rtol=1e-8, atol=1e-10)),
    t0=0.0,
    t1=20.0,
    dt0=0.01,
    y0=initial_state,
    args=params,
    saveat=diffrax.SaveAt(ts=t_eval),
    stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-10),
    max_steps=1600,
)
It would definitely be useful to have something like this directly in the API.