
Add benchmarks against `jax.experimental.ode`, torchdiffeq, DifferentialEquations.jl

Open patrick-kidger opened this issue 2 years ago • 3 comments

We now have a couple of simple ODE benchmarks, but could definitely afford to add a lot more.
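For illustration, a minimal head-to-head against `jax.experimental.ode` might look something like the sketch below -- the exponential-decay test problem, tolerances, and repeat count are all placeholders, not an agreed benchmark suite:

```python
import timeit

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

import diffrax

# Toy test problem: exponential decay dy/dt = -y.
# Note the different vector-field conventions: diffrax uses f(t, y, args),
# jax.experimental.ode uses f(y, t).
def vf_diffrax(t, y, args):
    return -y

def vf_jax(y, t):
    return -y

y0 = jnp.array([1.0])
ts = jnp.linspace(0.0, 10.0, 100)

@jax.jit
def solve_diffrax(y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vf_diffrax),
        diffrax.Dopri5(),
        t0=0.0,
        t1=10.0,
        dt0=0.1,
        y0=y0,
        saveat=diffrax.SaveAt(ts=ts),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-6),
    )
    return sol.ys

@jax.jit
def solve_jax(y0):
    return odeint(vf_jax, y0, ts, rtol=1e-6, atol=1e-6)

# Trigger compilation first, then time only the compiled solves.
solve_diffrax(y0).block_until_ready()
solve_jax(y0).block_until_ready()
print("diffrax:", timeit.timeit(lambda: solve_diffrax(y0).block_until_ready(), number=100))
print("jax.experimental.ode:", timeit.timeit(lambda: solve_jax(y0).block_until_ready(), number=100))
```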

patrick-kidger avatar Aug 17 '21 00:08 patrick-kidger

What sort of benchmarks would be useful for people looking into using these tools? Given that diffrax, JAX, and torchdiffeq are closely associated with machine/deep learning, it would make sense to benchmark problems that one would encounter in those fields. More broadly, if one is looking to solve a machine learning problem involving differential equations, one generally evaluates each option's ecosystem as a whole, not just its diffeq library.

Because of this, my suggestion is that we provide benchmarks in the form of a differential equation problem that fits into a wider machine learning problem. We could first present the timings of the diffeq part, then of the wider ML problem as a whole. A simple example could be calibrating a Lotka-Volterra model with MCMC, a problem which has the advantage of already being implemented in Turing and already benchmarked by diffrax.
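As a rough sketch of the diffeq part of that benchmark -- the parameter values, initial condition, time span, and tolerances below are all invented for illustration; `theta` stands in for whatever parameters the MCMC step would calibrate:

```python
import jax.numpy as jnp
import diffrax

# Lotka-Volterra vector field. theta = (alpha, beta, gamma, delta) are the
# parameters that the MCMC calibration would sample over.
def lotka_volterra(t, y, theta):
    prey, predator = y
    alpha, beta, gamma, delta = theta
    d_prey = alpha * prey - beta * prey * predator
    d_predator = -gamma * predator + delta * prey * predator
    return jnp.array([d_prey, d_predator])

def solve(theta):
    return diffrax.diffeqsolve(
        diffrax.ODETerm(lotka_volterra),
        diffrax.Tsit5(),
        t0=0.0,
        t1=10.0,
        dt0=0.1,
        y0=jnp.array([10.0, 5.0]),
        args=theta,
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 100)),
        stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-5),
    )

sol = solve(jnp.array([1.5, 1.0, 3.0, 1.0]))  # hypothetical parameter values
```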

I understand that this will require a fair bit of effort, which is why I wanted to discuss whether this is worthwhile and which benchmarks would be the most important before jumping into it.

Some other low-hanging fruit would simply be to take the examples from the documentation of the respective "competitor" packages and translate them into diffrax.
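For concreteness, one such translation might look roughly like this -- the dynamics matrix, time grid, and tolerances are made up, and the torchdiffeq snippet in the comment only follows the general style of its README rather than quoting it:

```python
import jax.numpy as jnp
import diffrax

# torchdiffeq original (roughly, in the style of its README):
#     from torchdiffeq import odeint
#     sol = odeint(lambda t, y: y @ A, y0, t, rtol=1e-7, atol=1e-9)
#
# diffrax translation of the same problem:

A = jnp.array([[-0.1, 2.0], [-2.0, -0.1]])  # illustrative dynamics matrix

def vector_field(t, y, args):
    return y @ A

ts = jnp.linspace(0.0, 25.0, 1000)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Dopri5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=0.01,
    y0=jnp.array([2.0, 0.0]),
    saveat=diffrax.SaveAt(ts=ts),
    stepsize_controller=diffrax.PIDController(rtol=1e-7, atol=1e-9),
)
```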

jacobusmmsmit avatar Aug 19 '22 11:08 jacobusmmsmit

I think all of the above sound like reasonable benchmarks. I would suggest focusing primarily on the diffeq part of any problem, though, rather than the entire ML system. The diffeq part is really the thing we care about, and benchmarking the whole system introduces a lot of extra maintenance overhead.

Translating examples certainly sounds like a good place to start. I would caution against reading too much into these, though -- a lot of examples are written for pedagogy rather than speed. Moreover, such examples are often small, but for torchdiffeq/torchsde/torchcde the focus is really on large-scale neural network problems, as PyTorch is intrinsically overhead-bound for small problems.

patrick-kidger avatar Aug 22 '22 13:08 patrick-kidger

I was trying to understand the current status of training hybrid neural ODEs in Julia and JAX. I found this old benchmark: https://gist.github.com/ChrisRackauckas/62a063f23cccf3a55a4ac9f6e497739a. Maybe now would be a good time to revisit the JAX part? For the Julia part there are alternative ways to implement it as well, but I don't expect the performance to change much.
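For reference, the JAX side of a hybrid neural ODE (a known physics term plus a learned correction) might be sketched roughly as follows -- the Equinox MLP, the `-0.5 * y` physics term, and the solve settings are all placeholders, not the gist's actual setup:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import diffrax

key = jax.random.PRNGKey(0)
mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=32, depth=2, key=key)

# Hybrid vector field: a known (here: linear) physics term plus a learned residual.
def hybrid_vf(t, y, mlp):
    physics = -0.5 * y   # illustrative known dynamics
    correction = mlp(y)  # learned correction term
    return physics + correction

@eqx.filter_jit
def solve(mlp, y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(hybrid_vf),
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=y0,
        args=mlp,
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 20)),
    )
    return sol.ys

ys = solve(mlp, jnp.array([1.0, 0.0]))
```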

jiweiqi avatar Oct 07 '22 13:10 jiweiqi