diffrax icon indicating copy to clipboard operation
diffrax copied to clipboard

Accelerate ODE solver [What did I miss?]

Open zhengqigao opened this issue 1 year ago • 3 comments

Hi,

I am playing around with diffrax's ODE-solving functionality. In a nutshell, I define a simple feed-forward MLP with random initialization and benchmark the run time of using it as the time derivative of an ODE. I wrote the following code to measure the run time of the ODE solve and got around 3.7 s, which seems much slower than other ODE solver frameworks.

I am new to JAX and diffrax. What did I miss in my code implementation?

import equinox as eqx
import jax
import diffrax
import jax.numpy as jnp
import time


class MLPeqx(eqx.Module):
    """Feed-forward MLP: ``eqx.nn.Linear`` layers with ``activation`` between them.

    No activation is applied after the final layer, so the output is a plain
    linear map of the last hidden representation.
    """

    # One Linear layer per consecutive pair of widths in ``hidden_dims``.
    layers: list
    # Static field: the activation function is not a pytree leaf to be traced.
    activation: callable = eqx.field(static=True)

    def __init__(self, hidden_dims, key=None):
        """Build the layers.

        Args:
            hidden_dims: Sequence of layer widths; e.g. ``[100, 100, 100]``
                builds two Linear layers (100 -> 100 -> 100).
            key: Optional PRNG key for weight initialisation. Defaults to
                ``jax.random.PRNGKey(0)``, so constructions without a key
                always produce the same weights.
        """
        super().__init__()
        if key is None:
            key = jax.random.PRNGKey(0)
        layer_keys = jax.random.split(key, len(hidden_dims) - 1)
        self.layers = [
            eqx.nn.Linear(in_dim, out_dim, key=layer_key)
            for in_dim, out_dim, layer_key in zip(hidden_dims[:-1], hidden_dims[1:], layer_keys)
        ]
        self.activation = jax.nn.relu

    def __call__(self, x):
        """Apply the MLP to ``x`` (one sample; batching is left to the caller)."""
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)


class ODEjax(eqx.Module):
    """Autonomous ODE vector field ``dy/dt = func(y)`` backed by an MLP."""

    # Network that computes the time derivative of the state.
    func: MLPeqx

    def __init__(self, hidden_dims):
        """Create the vector field from an ``MLPeqx`` with the given widths."""
        super().__init__()
        mlp = MLPeqx(hidden_dims)
        self.func = mlp

    def __call__(self, t, y, args=None):
        """Return dy/dt at state ``y``; ``t`` and ``args`` are ignored."""
        del t, args  # the dynamics depend on neither time nor extra args
        return self.func(y)


def solve_ode(input_x, t, func, cfg):
    """Integrate ``func`` from ``t[0]`` to ``t[-1]`` starting at ``input_x``.

    Args:
        input_x: Initial state ``y0``.
        t: Array of save times; its first and last entries bound the
            integration interval.
        func: Vector field callable as ``func(t, y, args)``.
        cfg: Dict with ``'method'`` (a diffrax solver) and ``'atol'`` /
            ``'rtol'`` (tolerances for the adaptive step-size controller).

    Returns:
        The solution evaluated at every time in ``t`` (``sol.ys``).
    """
    term = diffrax.ODETerm(func)
    controller = diffrax.PIDController(atol=cfg['atol'], rtol=cfg['rtol'])
    save_at = diffrax.SaveAt(ts=t)
    solution = diffrax.diffeqsolve(
        term,
        cfg['method'],
        t0=t[0],
        t1=t[-1],
        y0=input_x,
        dt0=None,  # let the step-size controller pick the initial step
        saveat=save_at,
        stepsize_controller=controller,
    )
    return solution.ys


def run_diffrax(hidden_dims, input_x, t, num_t, cfg):
    """Solve the MLP-defined ODE for every row of ``input_x``.

    Args:
        hidden_dims: Layer widths passed to ``ODEjax``.
        input_x: Batch of initial states; the solve is vmapped over axis 0.
        t: ``(t_start, t_end)`` integration interval.
        num_t: Number of evenly spaced save times in ``[t_start, t_end]``.
        cfg: Solver configuration forwarded to ``solve_ode``.

    Returns:
        Stacked solutions, one per row of ``input_x``.
    """
    t_start, t_end = t[0], t[1]
    save_times = jnp.linspace(t_start, t_end, num_t)
    vector_field = ODEjax(hidden_dims)
    # Only y0 is batched; save times, vector field and config are shared.
    batched_solve = jax.vmap(solve_ode, in_axes=(0, None, None, None))
    return batched_solve(input_x, save_times, vector_field, cfg)


if __name__ == '__main__':
    batch_size = 128
    hidden_dims = [100, 100, 100]
    # One initial state per batch element, each as wide as the first layer.
    input_x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, hidden_dims[0]))

    # perf_counter is a monotonic, high-resolution clock meant for intervals.
    start_time = time.perf_counter()
    ys = run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
        'method': diffrax.Dopri5(),
        'atol': 1e-5,
        'rtol': 1e-5})
    # JAX dispatches work asynchronously: wait for the result before stopping the timer.
    ys.block_until_ready()
    end_time = time.perf_counter()

    # NOTE: being the first call, this also includes any tracing/compilation time.
    print(f"run time = {end_time - start_time:.3f} (sec)")

zhengqigao avatar Jul 23 '24 14:07 zhengqigao

I recommend checking out JAX's docs on benchmarking (https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code). The TL;DR for this example is:

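  1. JAX compiles on the first call, so the first iteration is much slower; compilation time is generally excluded from benchmarks (do a warm-up call first).
  2. Because of asynchronous dispatch, you need to call `block_until_ready()` on the result before stopping the timer.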

With the following code I got: 19.4 ms ± 4.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

import equinox as eqx
import jax
import diffrax
import jax.numpy as jnp
import time


class MLPeqx(eqx.Module):
    """Feed-forward MLP: ``eqx.nn.Linear`` layers with ``activation`` between them.

    No activation is applied after the final layer, so the output is a plain
    linear map of the last hidden representation.
    """

    # One Linear layer per consecutive pair of widths in ``hidden_dims``.
    layers: list
    # Static field: the activation function is not a pytree leaf to be traced.
    activation: callable = eqx.field(static=True)

    def __init__(self, hidden_dims, key=None):
        """Build the layers.

        Args:
            hidden_dims: Sequence of layer widths; e.g. ``[100, 100, 100]``
                builds two Linear layers (100 -> 100 -> 100).
            key: Optional PRNG key for weight initialisation. Defaults to
                ``jax.random.PRNGKey(0)``, so constructions without a key
                always produce the same weights.
        """
        super().__init__()
        if key is None:
            key = jax.random.PRNGKey(0)
        layer_keys = jax.random.split(key, len(hidden_dims) - 1)
        # Pair each (in, out) width with its own key; iterating by value avoids
        # the index-variable typo (`for I in ...` while using `i`) that raised NameError.
        self.layers = [
            eqx.nn.Linear(in_dim, out_dim, key=layer_key)
            for in_dim, out_dim, layer_key in zip(hidden_dims[:-1], hidden_dims[1:], layer_keys)
        ]
        self.activation = jax.nn.relu

    def __call__(self, x):
        """Apply the MLP to ``x`` (one sample; batching is left to the caller)."""
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)


class ODEjax(eqx.Module):
    """Autonomous ODE vector field ``dy/dt = func(y)`` backed by an MLP."""

    # Network that computes the time derivative of the state.
    func: MLPeqx

    def __init__(self, hidden_dims):
        """Create the vector field from an ``MLPeqx`` with the given widths."""
        super().__init__()
        mlp = MLPeqx(hidden_dims)
        self.func = mlp

    def __call__(self, t, y, args=None):
        """Return dy/dt at state ``y``; ``t`` and ``args`` are ignored."""
        del t, args  # the dynamics depend on neither time nor extra args
        return self.func(y)


def solve_ode(input_x, t, func, cfg):
    """Integrate ``func`` from ``t[0]`` to ``t[-1]`` starting at ``input_x``.

    Args:
        input_x: Initial state ``y0``.
        t: Array of save times; its first and last entries bound the
            integration interval.
        func: Vector field callable as ``func(t, y, args)``.
        cfg: Dict with ``'method'`` (a diffrax solver) and ``'atol'`` /
            ``'rtol'`` (tolerances for the adaptive step-size controller).

    Returns:
        The solution evaluated at every time in ``t`` (``sol.ys``).
    """
    term = diffrax.ODETerm(func)
    controller = diffrax.PIDController(atol=cfg['atol'], rtol=cfg['rtol'])
    save_at = diffrax.SaveAt(ts=t)
    solution = diffrax.diffeqsolve(
        term,
        cfg['method'],
        t0=t[0],
        t1=t[-1],
        y0=input_x,
        dt0=None,  # let the step-size controller pick the initial step
        saveat=save_at,
        stepsize_controller=controller,
    )
    return solution.ys

@eqx.filter_jit
def run_diffrax(hidden_dims, input_x, t, num_t, cfg):
    """JIT-compiled batch solve of the MLP-defined ODE over ``input_x``.

    Under ``eqx.filter_jit``, array arguments are traced while non-array
    arguments (Python lists, Python floats, the solver object) are held
    static, so passing a different Python-float tolerance recompiles.

    Args:
        hidden_dims: Layer widths passed to ``ODEjax``.
        input_x: Batch of initial states; the solve is vmapped over axis 0.
        t: ``(t_start, t_end)`` integration interval.
        num_t: Number of evenly spaced save times in ``[t_start, t_end]``.
        cfg: Solver configuration forwarded to ``solve_ode``.

    Returns:
        Stacked solutions, one per row of ``input_x``.
    """
    t_start, t_end = t[0], t[1]
    save_times = jnp.linspace(t_start, t_end, num_t)
    vector_field = ODEjax(hidden_dims)
    # Only y0 is batched; save times, vector field and config are shared.
    batched_solve = jax.vmap(solve_ode, in_axes=(0, None, None, None))
    return batched_solve(input_x, save_times, vector_field, cfg)

import timeit

batch_size = 128
hidden_dims = [100, 100, 100]
# One initial state per batch element, each as wide as the first layer.
input_x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, hidden_dims[0]))
cfg = {
    'method': diffrax.Dopri5(),
    'atol': 1e-5,
    'rtol': 1e-5}


def _bench():
    """Run one solve and wait for it; block_until_ready() defeats async dispatch."""
    return run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, cfg).block_until_ready()


# Warm-up call: triggers tracing/compilation so it is excluded from the timing.
_ = _bench()

# Plain-Python equivalent of the IPython `%%timeit` cell magic (which is not
# valid syntax outside a notebook): 7 runs of 10 loops each.
loops, runs = 10, 7
per_loop = [total / loops for total in timeit.repeat(_bench, number=loops, repeat=runs)]
mean = sum(per_loop) / runs
std = (sum((x - mean) ** 2 for x in per_loop) / runs) ** 0.5
print(f"{mean * 1e3:.1f} ms ± {std * 1e3:.2f} ms per loop "
      f"(mean ± std. dev. of {runs} runs, {loops} loops each)")

lockwo avatar Jul 23 '24 16:07 lockwo

Thanks so much! I tried it on my end and observed similar run times. I have a follow-up question. Say I first want to run with atol=rtol=1e-5, and later in my code with atol=rtol=1e-4. I observe that run_diffrax becomes slow again when switching from 1e-5 to 1e-4, presumably because of recompilation. Namely,

# First call with atol=rtol=1e-5: takes ~2 s (includes compilation).
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-5,
    'rtol': 1e-5}).block_until_ready()

# Second call with atol=rtol=1e-5: takes ~0.008 s.
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-5,
    'rtol': 1e-5}).block_until_ready()

# First call with atol=rtol=1e-4: takes ~2 s again.
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-4,
    'rtol': 1e-4}).block_until_ready()

# Second call with atol=rtol=1e-4: takes ~0.008 s.
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-4,
    'rtol': 1e-4}).block_until_ready()

Is this behavior expected? Is there a way to compile only once for arbitrary atol=rtol values, so that every call runs at the millisecond level regardless of atol and rtol?

zhengqigao avatar Jul 24 '24 15:07 zhengqigao

Yes, this behavior is expected. Python floats are marked as static by the filtering that `eqx.filter_jit` performs before JIT compilation, so every new value triggers a recompile. You can make them non-static by passing them as JAX types (e.g. arrays):

# Tolerances passed as JAX arrays are traced rather than static under
# eqx.filter_jit, so only the first call compiles; the later calls reuse it
# even after the tolerance values change from 1e-5 to 1e-4.
for tol in (1e-5, 1e-5, 1e-4, 1e-4):
    start_time = time.perf_counter()  # monotonic clock meant for intervals
    run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
        'method': diffrax.Dopri5(),
        'atol': jnp.array(tol),
        'rtol': jnp.array(tol)}).block_until_ready()
    end_time = time.perf_counter()
    print(f"run time = {end_time - start_time:.3f} (sec)")
run time = 4.057 (sec)
run time = 0.016 (sec)
run time = 0.013 (sec)
run time = 0.013 (sec)

lockwo avatar Jul 24 '24 17:07 lockwo