
Significant performance difference: diffeqsolve vs. lax.scan - Expected Behavior?

Open AliAbdelwanis opened this issue 10 months ago • 4 comments

Hi,

I appreciate all the work that has gone into Diffrax! I'm trying to use Diffrax to simulate the interaction between an NN policy and an ODE system model over a given horizon length. Specifically, at each step the policy takes the system state as input and generates the control action to be applied to the system; the ODE is then solved to obtain the evolution of the state.

I have two implementations. The first uses lax.scan, and in each iteration the ODE is solved for one step using solver.init() and solver.step(). The second uses diffrax.diffeqsolve(), which takes the final simulation time and saveat=diffrax.SaveAt(ts=jnp.linspace(t0, t1, num_sim_steps)) as arguments, in addition to the term, solver, etc. In both implementations I use diffrax.Euler as the solver and augment the ODE with the policy network inside the vector field.

However, I noticed that the simulation with diffrax.diffeqsolve() is almost 3 times slower than with lax.scan and single steps, and the gap grows even larger when comparing diffrax.diffeqsolve() with adaptive steps against lax.scan with fixed single steps. The idea behind using diffrax.diffeqsolve() was to investigate whether reducing the number of steps, by adapting the step sizes, would improve simulation speed despite the overhead of the intermediate solver-related calculations and of steps rejected for exceeding the tolerances. What I don't understand is that with Euler I would expect both implementations to run at similar speeds, which is not the case. My second question is whether there is a way to use an adaptive solver to make the simulation faster for this application.

Here is a benchmark example for this comparison, in which I replaced the actual system with an arbitrary first-order ODE. I use Python (3.12.7), JAX (0.4.35), Equinox (0.11.10), and Diffrax (0.5.1), and I run the code on a 13th Gen Intel i7-1355U CPU.

Imports

import jax
import jax.numpy as jnp
import numpy as np

import equinox as eqx
import diffrax

import timeit

Define policy network class

class MLP(eqx.Module):
    """ class for a policy for providing control actions.
    """
    layers: list[eqx.nn.Linear]

    def __init__(self, layer_sizes, key):
        self.layers = []
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            key, subkey = jax.random.split(key)
            self.layers.append(eqx.nn.Linear(fan_in, fan_out, use_bias=True, key=subkey))

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.leaky_relu(layer(x))
        return self.layers[-1](x)

Case 1: lax.scan and single steps

def ode_step(init_state, ref, policy):
     """Method for simulating the policy-environment interaction and solving the ode for one step
 
     Args:
         init_state (jax.Array): state at the current step (y0)
         ref (jax.Array): reference state at the current step
         policy (MLP): for predicting the control action 
 
     Returns:
         jax.Array: updated state after one step simulation
     """
     args = (ref)
 
     # ode, including policy (MLP) and system: tau*dy/dt + y = u, assuming tau=1
     d_y = lambda t,y,args: (policy(jnp.concatenate([y, args]))-y)                                        
 
     term = diffrax.ODETerm(d_y)
     solver = diffrax.Euler()
     t0 = 0
     t1 = 1e-4
     y0 = init_state
     env_state = solver.init(term, t0, t1, y0, args)
     y, _, _, env_state, _ = solver.step(term, t0, t1, y0, args, env_state, made_jump=False)
     return y

def rollout_traj_scan(policy, init_states, ref_states, horizon_length):
     """rollout policy-environment interaction for 'horizon_length' for single sample using lax.scan
 
     Args:
         policy (MLP): predicting control actions
         init_states (jax.Array): initial environment state before rollout
         ref_states (jax.Array): reference to be tracked
         horizon_length (int): length of future predictions
 
     Returns:
         jax.Array: MSE tracking loss
     """
     # extending ref_states to horizon length
     ref_o = jnp.repeat(ref_states[None, :], horizon_length, axis=0)                                      
     
     def body_fun(carry, ref):
         state = carry
         state = ode_step(carry, ref, policy)
         return (state), (state)
 
     _, (states) = jax.lax.scan(body_fun, (init_states), ref_o, horizon_length)
 
     # error between the simulation state and reference
     error= states-ref_states                                                                           
 
     loss=jnp.mean((error)**2)                                                                               
     return jnp.clip(loss, max=1e5)

Case 2: Diffrax.diffeqsolve

def ode_diffeqsolve(policy, init_state, ref_state, horizon_length):
      """Method for rollout, simulating the policy-environment interaction, and solving the ODE for 'horizon_length' using diffrax.diffeqsolve.
  
      Args:
          policy (MLP): predicting control actions
          init_state (jax.Array): initial environment state before rollout
          ref_state (jax.Array): reference to be tracked
          horizon_length (int): length of future predictions
  
      Returns:
          jax.Array: state trajectory at predefined time steps
      """
      args = (ref_state)
  
      # ode, including policy (MLP) and system: tau*dy/dt + y = u, assuming tau=1
      d_y = lambda t,y,args: (policy(jnp.concatenate([y, args]))-y)                                        
                                                                                                           
      return diffrax.diffeqsolve(
          terms = diffrax.ODETerm(d_y), 
          solver = diffrax.Euler(), 
          t0 = 0, 
          t1 = horizon_length*1e-4, 
          dt0=1e-4, 
          y0=init_state, 
          args=args, 
          saveat=diffrax.SaveAt(ts=jnp.linspace(1e-4, horizon_length*1e-4, horizon_length)), 
          stepsize_controller=diffrax.ConstantStepSize()
      ).ys 

#Roll out using diffrax.diffeqsolve until t1
def rollout_traj_diffeqsolve(policy, init_states, ref_states, horizon_length):
      """Calls 'ode_diffeqsolve' for the rollout and calculates the MSE tracking loss
  
      Args:
          policy (MLP): for predicting control actions
          init_states (jax.Array): initial environment state before rollout
          ref_states (jax.Array): reference to be tracked
          horizon_length (int): length of future predictions
  
      Returns:
          jax.Array: MSE tracking loss
      """
  
      # get state trajectory through horizon_length
      states = ode_diffeqsolve(policy, init_states, ref_states, horizon_length)                          
  
      # error between the simulation state and reference
      error= states-ref_states                                                                           
  
      loss=jnp.mean((error)**2)
      return jnp.clip(loss, max=1e5)
      

Setup for the benchmark

jax_key = jax.random.PRNGKey(np.random.randint(0, 2**31))
jax_key, policy_key = jax.random.split(jax_key)

policy=MLP([2,20,20,20,1],key=policy_key)  # initialize policy network

state_key, ref_key = jax.random.split(jax_key)
init_states= jax.random.uniform(state_key, minval=0.0, maxval=30.0, shape=(1,))  # generate random initial state
ref_states =  jax.random.uniform(ref_key, minval=0.0, maxval=30.0, shape=(1,))  # generate random reference state

horizon_length = 25  # rollout length
train_steps = 5000  # number of iterations to be measured

Function for testing the speed of fwd and bwd propagations of the system

def speedtest(fcn, name):
    
    fwd = eqx.filter_jit(fcn)
    bwd = eqx.filter_jit(eqx.filter_grad(fcn))

    #Measure fwd time for train_steps iterations
    fwd_times = timeit.repeat(
        lambda: jax.block_until_ready(fwd(
        policy,
        init_states,
        ref_states,
        horizon_length,
    )), number=train_steps, repeat=10
    )                                                                                                      
    print(f"{name} fwd: {min(fwd_times)}")

    #Measure fwd+Bwd time for train_steps iterations
    bwd_times = timeit.repeat(
        lambda: jax.block_until_ready(bwd(
        policy,
        init_states,
        ref_states,
        horizon_length,
    )), number=train_steps, repeat=10
    )                                                                                                      
    print(f"{name} fwd+bwd: {min(bwd_times)}")

Run tests

speedtest(rollout_traj_scan, "scan")
speedtest(rollout_traj_diffeqsolve, "diffeqsolve")

AliAbdelwanis, Feb 17 '25 18:02

There's a fair amount that could be said here, so first I'll just remark that with certain Diffrax parameters you can get much closer to the lax.scan time, essentially by making it more of an apples-to-apples comparison.

Originally, your code gave:

scan fwd: 3.1685399269999834
scan fwd+bwd: 4.747113872
diffeqsolve fwd: 8.11498718200005
diffeqsolve fwd+bwd: 15.506070363000049

With this version (since we don't want any interpolation, and we want to save every step in memory):

      return diffrax.diffeqsolve(
          terms = diffrax.ODETerm(d_y), 
          solver = diffrax.Euler(), 
          t0 = 0, 
          t1 = horizon_length*1e-4, 
          dt0=1e-4, 
          y0=init_state, 
          args=args, 
          saveat=diffrax.SaveAt(steps=True),#ts=jnp.linspace(1e-4, horizon_length*1e-4, horizon_length)), 
          adjoint=diffrax.RecursiveCheckpointAdjoint(horizon_length),
          max_steps=horizon_length + 1,
          stepsize_controller=diffrax.ConstantStepSize(),
          throw=False,
      ).ys 

I see

diffeqsolve fwd: 3.5416423899999927
diffeqsolve fwd+bwd: 7.565716345000055

There's certainly more room for optimization, but hopefully this gives some ideas in that direction. There's also a fair amount of nuance when it comes to benchmarking Diffrax (see https://github.com/patrick-kidger/diffrax/issues/592, https://github.com/patrick-kidger/diffrax/issues/549, https://github.com/patrick-kidger/diffrax/issues/517, https://github.com/patrick-kidger/diffrax/issues/179, https://github.com/patrick-kidger/diffrax/issues/82, etc.).

Regarding your point on speed vs adaptivity:

However, I noticed that the simulation with Diffrax.diffeqsolve() is almost 3 times slower than with lax.scan and single steps, and this difference gets even bigger when comparing Diffrax.diffeqsolve() with 'Adaptive steps' and lax.scan with 'fixed single steps'. The idea behind using Diffrax.diffeqsolve() is that I wanted to investigate if reducing the number of steps, by adapting step sizes, would improve the simulation speed even in the presence of the overheads resulted from intermediate solver-related calculations and rejected steps when exceeding tolerances.

There are a couple of points here. First, make sure you check how many steps the solver is actually taking and rejecting. Next, you may be able to tune the parameters of the PID controller. Finally, and more importantly, what matters is error vs time. Since you are using adaptivity with a solver that isn't Euler and comparing it to Euler, the two will have different error rates. So you might want to plot error vs time (which is strongly correlated with error vs step size/tolerances) to find the best regime for your problem.

lockwo, Feb 17 '25 20:02

I'm away from my computer at the moment, but stepsize_controller=diffrax.StepTo(...) should also be used here -- it's a subtle point, but this, rather than ConstantStepSize, is the equivalent of a lax.scan.

Give that a try and let us know how it fares?

patrick-kidger, Feb 18 '25 07:02

Hi, thanks for your replies and suggestions. Previously, the output I got was as follows:

scan fwd: 1.2325921999999991
scan fwd+bwd: 2.7414333999995506
diffeqsolve fwd: 3.924495299999762
diffeqsolve fwd+bwd: 7.282417599999462

But after updating the arguments of diffeqsolve to

      return diffrax.diffeqsolve(
          terms = diffrax.ODETerm(d_y), 
          solver = diffrax.Euler(), 
          t0 = 0, 
          t1 = horizon_length*1e-4, 
          dt0=1e-4, 
          y0=init_state, 
          args=args, 
          saveat=diffrax.SaveAt(steps=True),
          adjoint=diffrax.RecursiveCheckpointAdjoint(horizon_length),
          max_steps=horizon_length + 1, 
          stepsize_controller=diffrax.ConstantStepSize(),
          throw=False,
      ).ys 

The speed improved. In particular, the diffeqsolve fwd time has become close to lax.scan, but fwd+bwd is still more than double:

scan fwd: 1.2437929000002441
scan fwd+bwd: 2.4098524000000907
diffeqsolve fwd: 1.6901795999997375
diffeqsolve fwd+bwd: 5.315119300000333

However, for some reason, using stepsize_controller=diffrax.StepTo(...) instead of ConstantStepSize, as in the following, increased the time again.

      return diffrax.diffeqsolve(
          terms = diffrax.ODETerm(d_y), 
          solver = diffrax.Euler(), 
          t0 = 0, 
          t1 = horizon_length*1e-4, 
          dt0=None, 
          y0=init_state, 
          args=args, 
          saveat=diffrax.SaveAt(steps=True),
          adjoint=diffrax.RecursiveCheckpointAdjoint(horizon_length),
          max_steps=horizon_length + 1, 
          stepsize_controller=diffrax.StepTo(ts=jnp.linspace(0, horizon_length*1e-4, horizon_length)),
          throw=False,
      ).ys 

which gives:

scan fwd: 1.2403713999997308
scan fwd+bwd: 2.598518999999669
diffeqsolve fwd: 3.682437399999799
diffeqsolve fwd+bwd: 7.916930800000046

AliAbdelwanis, Feb 18 '25 09:02

Okay, found the time to track this down.

First of all, here's a simplified benchmark script that runs the same ODE as above (I skipped the losses, which were the same for both versions) and demonstrates the two implementations achieving the same performance!

import os
import timeit

os.environ["EQX_ON_ERROR"] = "nan"

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr


def run(policy, y0, ref_y, horizon_length, diffeqsolve):
    term = diffrax.ODETerm(lambda t, y, args: policy(jnp.concatenate([y, ref_y])) - y)
    solver = diffrax.Euler()
    if diffeqsolve:
        t0 = 0
        t1 = 1e-4 * horizon_length
        sol = diffrax.diffeqsolve(
            terms=term,
            solver=solver,
            t0=t0,
            t1=t1,
            dt0=None,
            y0=y0,
            saveat=diffrax.SaveAt(steps=True),
            stepsize_controller=diffrax.StepTo(jnp.linspace(t0, t1, horizon_length + 1)),
            max_steps=horizon_length,
        )
        ys = sol.ys
    else:
        state = solver.init(term, 0, 1e-4, y0, None)

        def ode_step(t__y__state, _):
            t0, y, state = t__y__state
            t1 = t0 + 1e-4
            y, _, _, state, _ = solver.step(
                term, t0, t1, y, None, state, made_jump=False
            )
            return (t1, y, state), y

        _, ys = jax.lax.scan(ode_step, (0, y0, state), xs=None, length=horizon_length)
    return ys

fwd = eqx.filter_jit(eqx.debug.assert_max_traces(run, max_traces=2))

def speedtest(diffeqsolve, name):
    pkey, skey, rkey = jr.split(jr.key(0), 3)
    policy = eqx.nn.MLP(2, 1, 2, 3, activation=jax.nn.leaky_relu, key=pkey)
    y0 = jr.uniform(skey, minval=0.0, maxval=30.0, shape=(1,))
    ref_y = jr.uniform(rkey, minval=0.0, maxval=30.0, shape=(1,))
    horizon_length = 25

    runfwd = lambda: jax.block_until_ready(
        fwd(policy, y0, ref_y, horizon_length, diffeqsolve)
    )
    # Compile
    runfwd()
    fwd_times = timeit.repeat(runfwd, number=100, repeat=10)
    print(f"{name} fwd: {min(fwd_times)}")


speedtest(False, "scan")  # scan fwd: 0.010120792023371905
speedtest(True, "diffeqsolve")  # diffeqsolve fwd: 0.010992499999701977

Okay, what changed relative to the earlier benchmark scripts?

First of all, I've set all the parameters of diffeqsolve to follow lax.scan-like behaviour. (I've also tidied up the raw lax.scan implementation in various ways, none of which affect performance.) After that, the magic trick is setting os.environ["EQX_ON_ERROR"] = "nan"! (Although you may want to leave the checks enabled anyway; see below.)

What's going on is that this disables some of the runtime checks that Diffrax performs (for example, to make sure that your integration times are monotonic). When benchmarking two implementations against each other, runtime checks are a really common source of overhead. In the implementation above there are two of them in total, here:

https://github.com/patrick-kidger/diffrax/blob/14baa1edddcacf27c0483962b3c9cf2e86e6e5b6/diffrax/_step_size_controller/constant.py#L78

https://github.com/patrick-kidger/diffrax/blob/14baa1edddcacf27c0483962b3c9cf2e86e6e5b6/diffrax/_step_size_controller/constant.py#L103

Now why might you want to leave these on? Because their cost is purely additive: they run once at the start of each diffeqsolve, and do not scale with the complexity of the vector field, the number of steps, etc. From the numbers above, they add about 0.0002 seconds of overhead to each solve. For most use cases that's a totally acceptable overhead! In return, we know that we won't silently be doing something wrong.

The only reason the overhead grows as large as what you saw before is that you're doing many very tiny solves. (If that really is your actual use case, then indeed you may wish to disable the checks.)

The above benchmark skips the backward pass. Diffrax always does this with recursive checkpointing. This is a pretty awesome feature that's unique to Diffrax -- it allows backpropagation through adaptive time stepping, which may take an arbitrary number of steps (not something you can do with lax.scan) -- but it means that there isn't a direct apples-to-apples comparison for that one. (We could probably write a custom diffeqsolve(..., adjoint=...) that implements lax.scan-like behaviour when the trip count is known statically, though, if someone felt strongly enough about this.)

I hope that helps!

patrick-kidger, Feb 21 '25 18:02