
How to enforce non-negativity constraints?

Open • Mycroft-47 opened this issue 3 months ago • 2 comments

I'm solving an ODE system for viral dynamics using Kvaerno4 with a PID controller. The state variables must remain non-negative.

I searched the docs and issues but didn't find anything about handling non-negativity constraints. I'm currently forcing the states to be non-negative by clamping them to zero inside the vector field:

import jax.numpy as jnp

def ode_system(t, state, params):
    state = jnp.maximum(state, 0.0)  # <-- this is my current workaround
    # ..... compute derivatives ....
    return derivatives

But this is generally a bad approach: it interferes with error control and makes the vector field non-smooth at zero, which is why I'm asking for guidance.

MATLAB's ode15s supports non-negativity through the NonNegative option of odeset. The solver does two things: first, it wraps the ODE function to modify the derivatives; second, it incorporates constraint violations into the error estimate used to accept or reject steps. Inspecting odenonnegative.m shows the derivative modification:

function yp = local_odeFcn_nonnegative(idxNonNegative, ode, t, y, varargin)
    yp = feval(ode, t, y, varargin{:});
    ndx = idxNonNegative(find(y(idxNonNegative) <= 0));
    yp(ndx) = max(yp(ndx), 0); % <-- here
end
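For reference, the clamping step of that MATLAB wrapper translates almost line for line into JAX. The main differences: JAX arrays are immutable, so the update goes through `.at[...].set`; indices are 0-based; and a boolean mask with `jnp.where` replaces `find`, which keeps the function traceable under `jax.jit`. This is a sketch: `nonnegative_derivative` is an illustrative name, and it takes the already-computed derivative instead of calling the ODE function itself.

```python
import jax.numpy as jnp


def nonnegative_derivative(y, dydt, idx_nonneg):
    """Illustrative JAX port of the clamping in local_odeFcn_nonnegative."""
    idx = jnp.asarray(idx_nonneg)          # like idxNonNegative, but 0-based
    at_or_below_zero = y[idx] <= 0         # boolean mask instead of find(...)
    clamped = jnp.where(at_or_below_zero, jnp.maximum(dydt[idx], 0.0), dydt[idx])
    return dydt.at[idx].set(clamped)       # functional update: JAX arrays are immutable


# Components 0-2 are constrained; component 3 is not, so it keeps its negative slope.
y = jnp.array([1.0, 0.0, -0.1, 0.0])
dydt = jnp.array([-1.0, -2.0, 0.5, -3.0])
out = nonnegative_derivative(y, dydt, [0, 1, 2])  # values: [-1.0, 0.0, 0.5, -3.0]
```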

Then, during the main integration loop, after computing a candidate step, the solver checks whether any constrained variable has gone negative and, if so, computes an additional error term that can trigger step rejection.
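A simplified version of that extra error term can be sketched in JAX as a helper that converts negative constrained components of a candidate step into an error ratio comparable with the usual rtol/atol-scaled error norm (above 1 means reject). This is a sketch under stated assumptions, not a transcription of ode15s: the helper names, the max-norm, the per-component weights `atol + rtol * max(|y_old|, |y_new|)`, and the final projection of small accepted negatives back onto zero are all choices made here.

```python
import jax.numpy as jnp


def nonneg_violation_error(y_new, y_old, idx_nonneg, rtol, atol):
    """Hypothetical helper: error ratio for constrained components of a candidate
    step that went negative. A value above 1 means "reject the step"."""
    idx = jnp.asarray(idx_nonneg)
    # Per-component weights in the usual rtol/atol form (an assumption of this sketch).
    scale = atol + rtol * jnp.maximum(jnp.abs(y_old[idx]), jnp.abs(y_new[idx]))
    violation = jnp.maximum(-y_new[idx], 0.0)  # distance below zero, 0 if non-negative
    return jnp.max(violation / scale)


def accept_and_project(y_new, idx_nonneg):
    """Hypothetical helper: after accepting a step, clip the small negative
    values that passed the violation check back onto zero."""
    idx = jnp.asarray(idx_nonneg)
    return y_new.at[idx].set(jnp.maximum(y_new[idx], 0.0))


y_old = jnp.array([1.0, 0.5])
tiny_dip = jnp.array([0.2, -1e-9])  # violation far below tolerance -> accept
big_dip = jnp.array([0.2, -0.1])    # violation far above tolerance -> reject
err_tiny = nonneg_violation_error(tiny_dip, y_old, [1], rtol=1e-3, atol=1e-6)
err_big = nonneg_violation_error(big_dip, y_old, [1], rtol=1e-3, atol=1e-6)
y_accepted = accept_and_project(tiny_dip, [1])  # second component becomes 0.0
```

In an adaptive loop, the step would then be accepted or rejected based on the maximum of the ordinary error ratio and this violation ratio.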

Is there a recommended pattern for this in Diffrax? I'm still fairly new to JAX and Diffrax, so I don't understand the internals well enough to implement something similar myself. I'd appreciate any pointers, or an existing approach I'm missing.

Mycroft-47, Oct 05 '25 22:10

Yup, this is totally possible. Take a look at the InBoundsSolver in https://github.com/patrick-kidger/diffrax/issues/200.

(Which we should probably add to Diffrax's API directly!)
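As I read the InBoundsSolver idea in that issue: wrap an existing solver and, after each candidate step, replace the error estimate with infinity whenever the candidate state leaves the allowed region, so that an adaptive controller such as `diffrax.PIDController` rejects the step and retries with a smaller one. The sketch below isolates only that error-inflation logic in plain JAX; `reject_if_out_of_bounds` is a hypothetical name, and wiring it into the `step` method of a `diffrax.AbstractSolver` subclass is not shown.

```python
import jax.numpy as jnp


def reject_if_out_of_bounds(y_candidate, y_error, lower=0.0):
    """Hypothetical helper: if any component of the candidate state is below
    `lower`, replace the error estimate with +inf so that an adaptive
    step-size controller rejects the step and retries with a smaller one."""
    out_of_bounds = jnp.any(y_candidate < lower)
    return jnp.where(out_of_bounds, jnp.inf, y_error)


y_error = jnp.array([1e-3, 2e-3])
rejected = reject_if_out_of_bounds(jnp.array([0.5, -0.01]), y_error)  # all inf
kept = reject_if_out_of_bounds(jnp.array([0.5, 0.01]), y_error)       # unchanged
```

Because any violation triggers rejection, a component that the dynamics push towards zero in finite time can force very small steps near the boundary; a violation term scaled by the error tolerance, as described for ode15s, is a softer alternative.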

patrick-kidger, Oct 06 '25 12:10

Thanks, that InBoundsSolver approach is exactly what I needed to see. I ended up implementing something similar: a wrapper around the ODE function that clamps the derivatives when variables hit zero:

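import jax.numpy as jnp


class NonNegativeWrapper:
    """Wrap an ODE vector field so that constrained components that are at or
    below zero cannot have a negative derivative."""

    def __init__(self, ode_func, nonneg_indices=None):
        self.ode_func = ode_func
        # JAX does not accept a plain Python list as an index, so store an array.
        self.nonneg_indices = (
            None if nonneg_indices is None else jnp.asarray(nonneg_indices)
        )

    def __call__(self, t, y, args):
        dydt = self.ode_func(t, y, args)

        if self.nonneg_indices is None:
            # Constrain every component.
            mask = y <= 0
            dydt = jnp.where(mask, jnp.maximum(dydt, 0.0), dydt)
        else:
            # Constrain only the selected components.
            idx = self.nonneg_indices
            mask = y[idx] <= 0
            dydt = dydt.at[idx].set(
                jnp.where(mask, jnp.maximum(dydt[idx], 0.0), dydt[idx])
            )

        return dydt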

Usage is straightforward:

import diffrax
import optimistix

# ode_system, initial_state, params and t_eval are defined elsewhere.
wrapped_ode = NonNegativeWrapper(
    ode_system,
    nonneg_indices=None,  # or specify indices like [0, 1, 3]
)

term = diffrax.ODETerm(wrapped_ode)
solution = diffrax.diffeqsolve(
    terms=term,
    solver=diffrax.Kvaerno4(root_finder=optimistix.Newton(rtol=1e-8, atol=1e-10)),
    t0=0.0,
    t1=20.0,
    dt0=0.01,
    y0=initial_state,
    args=params,
    saveat=diffrax.SaveAt(ts=t_eval),
    stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-10),
    max_steps=1600,
)
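To see how derivative clamping, a violation-based error term, and projection onto zero interact over a whole solve, here is a self-contained toy integrator in plain JAX rather than Diffrax (doing the same inside Diffrax would mean subclassing one of its solver classes, which this sketch does not attempt). Everything here is an illustrative assumption: the Heun-Euler embedded pair, the controller constants (0.9 safety factor, step-size change clipped to [0.2, 5]), the hypothetical function names, and the toy model. It uses an eager Python loop for readability, so it is slow and not jit-compiled.

```python
import jax.numpy as jnp


def clamp_derivative(f, idx):
    """Wrap f(t, y) so constrained components at or below zero cannot decrease."""
    idx = jnp.asarray(idx)

    def wrapped(t, y):
        dydt = f(t, y)
        d = dydt[idx]
        return dydt.at[idx].set(jnp.where(y[idx] <= 0, jnp.maximum(d, 0.0), d))

    return wrapped


def integrate_nonneg(f, y0, t0, t1, idx, rtol=1e-3, atol=1e-6, h=0.1, max_steps=10_000):
    """Hypothetical toy adaptive Heun-Euler integrator with non-negativity handling."""
    f = clamp_derivative(f, idx)
    idx = jnp.asarray(idx)
    t, y = float(t0), jnp.asarray(y0, dtype=jnp.float32)
    steps = 0
    while t < t1:
        steps += 1
        if steps > max_steps:
            raise RuntimeError("max_steps exceeded")
        h = min(h, t1 - t)                       # never step past t1
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_new = y + 0.5 * h * (k1 + k2)          # Heun (2nd-order) solution
        err_vec = 0.5 * h * (k2 - k1)            # Heun minus explicit Euler
        scale = atol + rtol * jnp.maximum(jnp.abs(y), jnp.abs(y_new))
        err = float(jnp.max(jnp.abs(err_vec) / scale))
        # Extra error term: how far constrained components went below zero.
        err_nn = float(jnp.max(jnp.maximum(-y_new[idx], 0.0) / scale[idx]))
        err = max(err, err_nn)
        if err <= 1.0:                           # accept the step
            t += h
            # Project small, within-tolerance negatives back onto zero.
            y = y_new.at[idx].set(jnp.maximum(y_new[idx], 0.0))
        # Local error estimate scales like h**2, hence the -1/2 exponent.
        h *= (5.0 if err == 0.0 else min(5.0, max(0.2, 0.9 * err ** -0.5)))
    return t, y


def consumption(t, y):
    # Hypothetical toy model: y[0] is consumed at a constant rate (it would go
    # negative at t = 1 without the constraint); y[1] decays exponentially.
    return jnp.array([-1.0, -0.5 * y[1]])


t_end, y_end = integrate_nonneg(consumption, [1.0, 1.0], 0.0, 3.0, idx=[0])
```

Without the projection, an accepted step could leave a constrained component at a tiny negative value; without the extra error term, a step that jumps below zero could still be accepted whenever the ordinary error estimate happens to be small.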

It would definitely be useful to have something similar directly in the API.

Mycroft-47, Oct 06 '25 23:10