diffrax
Support for uncertainty propagation?
It would be amazing to be able to specify uncertainties on initial conditions (ICs) and have them propagate through to the solutions. I know this is a very difficult problem in general. Off the top of my head, some challenges are:
- Gaussian uncertainties are already hard, and operations on arbitrary (e.g. non-symmetric) distributions often don't even have an analytic form. This could be approximated to first order...
- What would the new API look like?
- How would dense solutions be supported?
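The "approximate to first order" option from point 1 can be sketched in plain JAX. The idea is to linearise the solution map y0 ↦ y(t1) with `jax.jacfwd` and then push a Gaussian covariance through it, Σ₁ ≈ J Σ₀ Jᵀ. This sketch rests on two assumptions. First, `rk4_solve` is a hypothetical fixed-step RK4 stand-in for `diffeqsolve`, so the example runs without Diffrax. Second, the means and standard deviations are made up for illustration.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    # Same ODE as the getting-started example: dy/dt = -y.
    return -y


def rk4_solve(y0, t1=1.0, num_steps=100):
    # Hypothetical fixed-step RK4 integrator standing in for diffeqsolve;
    # returns y(t1) given y(0) = y0.
    dt = t1 / num_steps

    def step(carry, _):
        t, y = carry
        k1 = vector_field(t, y, None)
        k2 = vector_field(t + dt / 2, y + dt / 2 * k1, None)
        k3 = vector_field(t + dt / 2, y + dt / 2 * k2, None)
        k4 = vector_field(t + dt, y + dt * k3, None)
        return (t + dt, y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)), None

    (_, y1), _ = jax.lax.scan(step, (jnp.array(0.0), y0), None, length=num_steps)
    return y1


# Illustrative Gaussian initial condition: means and a diagonal covariance.
y0_mean = jnp.array([1.0, 2.0])
y0_cov = jnp.diag(jnp.array([0.1, 0.2]) ** 2)

# Linearise the solution map around the mean and propagate the covariance.
J = jax.jacfwd(rk4_solve)(y0_mean)  # J = d y(t1) / d y0, shape (2, 2)
y1_mean = rk4_solve(y0_mean)
y1_cov = J @ y0_cov @ J.T
y1_std = jnp.sqrt(jnp.diag(y1_cov))
print(y1_mean, y1_std)
```

The ODE here is linear, so the linearisation is exact and the standard deviations simply shrink by exp(-1). For a nonlinear vector field the result is only a first-order approximation, and that is the trade-off point 1 refers to.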
Thankfully, I think at least point 2 has a workable solution. @patrick-kidger, you've written Quax to allow for array-like objects in JAX. My suggestion would be to make `diffeqsolve(y0=...)` accept Quax classes that handle the distribution and its propagation.
Point 1 remains hard, but there's a useful starting point. The simplest "uncertainty" to support isn't even Gaussian. It's a plain interval given by a lower and an upper bound. That would be a good proof of concept while still being useful! I know the same result could be obtained by calling `diffeqsolve` twice, once per bound. Even so, a) a unified API would be a convenience and b) we could hopefully build on it later to implement Gaussian and more complex distributions.
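The "call it twice" baseline can be sketched in plain JAX by batching the two endpoint solves with `jax.vmap`. This sketch assumes `solve_at` is a hypothetical fixed-step RK4 stand-in for `diffeqsolve` on dy/dt = -y that returns y at each time in `ts`. It is not Diffrax API.

```python
import jax
import jax.numpy as jnp


def solve_at(y0, ts, num_substeps=50):
    # Hypothetical fixed-step RK4 stand-in for diffeqsolve on dy/dt = -y,
    # returning y at each time in `ts` (ts[0] is the initial time).
    f = lambda y: -y

    def advance(y, t_pair):
        t_start, t_end = t_pair
        dt = (t_end - t_start) / num_substeps

        def substep(y, _):
            k1 = f(y)
            k2 = f(y + dt / 2 * k1)
            k3 = f(y + dt / 2 * k2)
            k4 = f(y + dt * k3)
            return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4), None

        y, _ = jax.lax.scan(substep, y, None, length=num_substeps)
        return y, y  # carry the state forward and also record it

    _, ys = jax.lax.scan(advance, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys])


ts = jnp.array([0.0, 1.0, 2.0, 3.0])
bounds = jnp.array([0.9, 1.1])  # lower and upper bound of the initial condition
# One solve per endpoint, batched with vmap instead of two separate calls.
lower, upper = jax.vmap(solve_at, in_axes=(0, None))(bounds, ts)
print(lower)  # approximately 0.9 * exp(-ts)
print(upper)  # approximately 1.1 * exp(-ts)
```

For a scalar ODE like this one, uniqueness of solutions means trajectories cannot cross, so the two endpoint solutions bound every trajectory that starts in between. For systems of ODEs, solving from the corners of a box does not in general bound the set of solutions.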
Adapting the opening example from https://docs.kidger.site/diffrax/usage/getting-started/:
```python
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController
from diffrax import Interval  # proposed API

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=Interval(0.9, 1.1),  # note the Interval
                  saveat=saveat, stepsize_controller=stepsize_controller)

print(sol.ts)  # DeviceArray([0. , 1. , 2. , 3. ])
print(sol.ys)  # Interval(...)  (internals TBD)
```
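A proposed `Interval` type would need an interval-arithmetic rule for each primitive. The sketch below uses plain Python, and `Interval`, `euler_naive` and `euler_rearranged` are hypothetical names, not Diffrax or Quax API. It shows one pitfall of evaluating primitive by primitive, which is how per-primitive overloads would see the computation. An Euler step written as `y + dt * (-y)` treats the two occurrences of `y` as independent and widens the interval. The algebraically identical `y * (1 - dt)` stays tight.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    # Hypothetical interval type: a real value known only to lie in [lo, hi].
    lo: float
    hi: float

    def __neg__(self):
        # Negation flips the bounds.
        return Interval(-self.hi, -self.lo)

    def __add__(self, other):
        # Sum of two intervals: add the lower and upper bounds separately.
        return Interval(self.lo + other.lo, self.hi + other.hi)

    def scale(self, c):
        # Multiplication by a scalar constant c; a negative c flips the bounds.
        a, b = c * self.lo, c * self.hi
        return Interval(min(a, b), max(a, b))


def euler_naive(y, dt):
    # One Euler step y + dt * f(y) with f(y) = -y, evaluated term by term.
    return y + (-y).scale(dt)


def euler_rearranged(y, dt):
    # Algebraically identical step y * (1 - dt); y now appears only once.
    return y.scale(1.0 - dt)


y0 = Interval(0.9, 1.1)
print(euler_naive(y0, 0.1))       # about [0.79, 1.01]: wider than the true set
print(euler_rearranged(y0, 0.1))  # about [0.81, 0.99]: the true set after one step
```

This is the classic dependency problem of interval arithmetic. Over many solver steps, the naive width grows by a factor of (1 + dt) per step while the true width shrinks by (1 - dt). That is one reason point 1 remains hard even for plain intervals.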
As a related note, Quax classes would also allow an array of sampled `y0` values to be bundled into a Monte Carlo measurement (`MCMeasurement` below) that approximates the uncertainty distribution:
```python
sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=MCMeasurement(...),
                  saveat=saveat, stepsize_controller=stepsize_controller)
print(sol.ts)         # DeviceArray([0. , 1. , 2. , 3. ])
print(sol.ys)         # MCMeasurement(...)
print(sol.ys.mean())  # DeviceArray([1. , 0.368, 0.135, 0.0498])
print(sol.ys.std())   # DeviceArray([...])
```
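Without Quax, the computation a proposed `MCMeasurement` would bundle up can be sketched with `jax.vmap`: solve once per sampled initial condition, then take statistics across the batch. This sketch assumes the following. `solve_at` is a hypothetical fixed-step RK4 stand-in for `diffeqsolve` on dy/dt = -y. The sample count is 1000. The initial distribution, with mean 1.0 and standard deviation 0.05, is illustrative only.

```python
import jax
import jax.numpy as jnp


def solve_at(y0, ts, num_substeps=50):
    # Hypothetical fixed-step RK4 stand-in for diffeqsolve on dy/dt = -y,
    # returning y at each time in `ts` (ts[0] is the initial time).
    def advance(y, t_pair):
        dt = (t_pair[1] - t_pair[0]) / num_substeps

        def substep(y, _):
            k1 = -y
            k2 = -(y + dt / 2 * k1)
            k3 = -(y + dt / 2 * k2)
            k4 = -(y + dt * k3)
            return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4), None

        y, _ = jax.lax.scan(substep, y, None, length=num_substeps)
        return y, y

    _, ys = jax.lax.scan(advance, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys])


ts = jnp.array([0.0, 1.0, 2.0, 3.0])
key = jax.random.PRNGKey(0)
# Illustrative samples of the uncertain initial condition.
y0_samples = 1.0 + 0.05 * jax.random.normal(key, (1000,))
# One trajectory per sample; this batch plays the role of the MCMeasurement.
ys = jax.vmap(solve_at, in_axes=(0, None))(y0_samples, ts)  # shape (1000, 4)
print(ys.mean(axis=0))  # close to exp(-ts)
print(ys.std(axis=0))   # close to 0.05 * exp(-ts)
```

A wrapper type would mostly add convenience here. It would carry the sample axis implicitly, so user code could keep treating `y0` as a single value.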
I think this'd be really cool!
So, as you've highlighted, Quax is the way to go about implementing something like this. It would work by providing overloads for every primitive that `diffeqsolve` uses, and then calling `quax.quaxify(diffrax.diffeqsolve)(y0=Interval(...), ...)`.
In particular, the point of a library like Quax is that this shouldn't require changing Diffrax at all.
Providing all of those overloads is probably fairly ambitious (and might stress-test just how load-bearing Quax really is 😅), but if you do something like that I'd love to see it!
As a simple example:
```python
from collections.abc import Sequence

import equinox as eqx
import jax
import jax.numpy as jnp
import quax
from jax.typing import ArrayLike
from diffrax import diffeqsolve, Dopri5, ODETerm, PIDController, SaveAt


class ExactMeasurement(quax.ArrayValue):
    # Wraps a plain array; every primitive is simply forwarded to it.
    array: ArrayLike = eqx.field(converter=jnp.asarray)

    def aval(self):
        # The abstract value (shape and dtype) JAX sees for this object.
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        # Never silently convert back to a plain array.
        msg = "Refusing to materialise."
        raise ValueError(msg)

    @staticmethod
    def default(
        primitive: jax.core.Primitive,
        values: Sequence[ArrayLike | quax.Value],
        params: dict,
    ):
        # Fallback rule for any primitive without a specific overload:
        # unwrap the inputs, apply the primitive, then re-wrap the outputs.
        raw_values: list[ArrayLike] = []
        for value in values:
            if eqx.is_array_like(value):
                raw_values.append(value)
            elif isinstance(value, ExactMeasurement):
                raw_values.append(value.array)
            elif isinstance(value, quax.Value):
                raise NotImplementedError
            else:
                raise AssertionError
        out = primitive.bind(*raw_values, **params)
        return (
            [ExactMeasurement(x) for x in out]
            if primitive.multiple_results
            else ExactMeasurement(out)
        )


vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0.0, 1.0, 2.0, 3.0])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
sol = quax.quaxify(diffeqsolve)(
    term,
    solver,
    t0=0,
    t1=3,
    dt0=0.1,
    y0=ExactMeasurement(1.0),
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)
print(sol.ts)  # ExactMeasurement(array=f32[4])
print(sol.ys)  # ExactMeasurement(array=f32[4])
```
Having only the default rule is obviously too permissive. For instance, `sol.ts` also comes back as an `ExactMeasurement` even though the save times carry no measurement. Still, it shows the approach is possible.