Large speed difference between forward and backward passes

Open deasmhumhna opened this issue 2 years ago • 2 comments

I just discovered the Diffrax package and it's great! However, I'm encountering an issue where the gradient evaluation is 40-80 times slower than the forward pass for my particular network (a potential network whose gradient is the vector field). The difference factor also seems to grow with the number of steps used in the integration loop. When the model is still near the identity map and the number of steps is low, the forward pass takes ~50 ms and the backward + update pass takes ~2 s (x40). Later on, the forward pass takes ~500 ms while the backward pass + update takes ~40 s (x80). These times are on a single GPU. I expect the gradient to take longer than the forward evaluation but 80 times slower seems extreme, even with the extra compute from checkpointing. For reference, evaluating the vector field takes ~2 ms while calculating the gradient on the norm of the vector field takes ~12 ms. I'll create a minimal working example shortly. I am using Flax so I wonder if that has something to do with it.

EDIT: Example in Colab: single conditioned 3x3 convolution followed by logsumexp pooling and some final ResNet layers. Similar behavior reproduced on CPU. Curious if it has something to do with the convolutions.

deasmhumhna, Oct 21 '22 18:10

FWIW a surprisingly large speed difference is often the case here, due to the memory overhead of looking up the stored residuals. Despite this, I agree that what you've got here looks like a strangely large speed difference.

Thank you for the clean repro. Your code looks good; I don't see any obvious mistakes.

I don't have a simple answer for you. Here are some things you can try, to see what happens:

  1. One point of funny business that does jump out at me: your first runs only take 3 steps inside the differential equation solver. Given that you start with a step of size 0.01, what this almost certainly means is that the step size is growing by the maximum factor of 10 at every step, so that the step locations are 0, 0.01, 10*0.01+0.01, min(1, 10*10*0.01+10*0.01+0.01), i.e. 0, 0.01, 0.11, 1. I've not dug into what's going on here, but it seems like your tolerances are too loose, or your vector field too close to zero, for you to meaningfully be said to be solving a differential equation. Especially as you're using such a simple solver (Heun). (Although I'm not convinced that this really explains anything.)

  2. It's very possible that the use of higher-order autodiff is causing some kind of misbehaviour. JAX promises to be functionally correct here, but I think compiling to efficient code doesn't always happen as well as we'd like. You might be able to narrow a repro down to something in this space, e.g. by trying a diffeqsolve without higher-order autodiff, or by trying higher-order autodiff without a diffeqsolve.

  3. Possibly try using jax.experimental.ode.odeint or torchdiffeq as well. In practice Diffrax dominates both on every problem I've tried it on, but it may be helpful to try these even so.
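
The step locations quoted in point 1 can be reproduced with a few lines of arithmetic. This is only a sketch: it assumes every step is accepted and that the step size grows by the maximum factor of 10 each time, and `step_locations` is a hypothetical helper, not a Diffrax function.

```python
def step_locations(t0, t1, dt0, max_factor=10.0):
    """Step locations if every step is accepted and the step size
    grows by max_factor after each step (assumption, see above)."""
    ts = [t0]
    dt = dt0
    while ts[-1] < t1:
        # Clip the final step so that it lands exactly on t1.
        ts.append(min(t1, ts[-1] + dt))
        dt *= max_factor
    return ts


ts = step_locations(0.0, 1.0, 0.01)
print(len(ts) - 1, [round(t, 2) for t in ts])
```

Under these assumptions the interval [0, 1] is covered in 3 steps, which matches the step count observed in the first runs.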

And some other thoughts in passing:

  1. I am currently in the process of writing an improved core integration loop, that doesn't use checkpointing. This will probably speed up times somewhat.

  2. You can compute the gradient wrt t more cheaply by using forward-mode autodifferentiation:

      fn = lambda t: state.apply_fn(params, t, y)
      # jax.jvp returns (fn(t), d fn/dt); the unit tangent picks out the derivative wrt t
      out, dout_dt = jax.jvp(fn, (t,), (jnp.ones_like(t),))

  3. I suspect the choice of Flax vs Equinox doesn't affect things here, but you could always try an Equinox implementation to be sure.
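
As a self-contained illustration of the forward-mode suggestion above, the following sketch compares `jax.jvp` with reverse-mode `jax.grad` for the derivative with respect to `t`. The function `potential` and the values of `params`, `y` and `t` are hypothetical stand-ins for `state.apply_fn` and the real model inputs.

```python
import jax
import jax.numpy as jnp


def potential(params, t, y):
    # Hypothetical stand-in for state.apply_fn: a scalar that depends on t and y.
    return jnp.sum(jnp.tanh(params * y)) * jnp.sin(t)


params = jnp.array([0.5, -1.0, 2.0])
y = jnp.array([1.0, 2.0, 3.0])
t = jnp.array(0.3)

fn = lambda t: potential(params, t, y)

# Forward mode: one pass yields both the value and the derivative wrt t.
value, dvalue_dt = jax.jvp(fn, (t,), (jnp.ones_like(t),))

# Reverse mode gives the same derivative, but has to store residuals first.
dvalue_dt_rev = jax.grad(fn)(t)

print(value, dvalue_dt, dvalue_dt_rev)
```

Forward mode is a natural fit here because `t` is a single scalar input, so one tangent direction is enough to recover the full derivative.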

patrick-kidger, Oct 25 '22 07:10

Thanks for the quick response. I was using jax.experimental.ode.odeint before I found Diffrax and I can't remember the exact difference between train_step and eval_step. I switched because odeint was taking forever to compile (my actual model is much larger/deeper). I updated the example to include a comparison with odeint and it seems to be experiencing a similar slowdown so the issue might be compiler inefficiency on higher-order differentiation. I'll try some comparisons with and without it.

edit: The memory overhead could also be part of it, as the time difference decreased when I decreased the image size. With a 56x56 image, the time factors for comparable potential and vector field networks are both ~20x, which seems a little more reasonable. The potential network is more memory efficient, as the vector field network caused an OOM exception on the original 224x224 image. But that efficiency may be coming at the cost of a longer backward pass. This could suggest the issue stems from the interaction of higher-order derivatives and checkpointing. I'll have to test this.
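
One way such a test might be set up without Diffrax is sketched below. Everything here is an assumption made for illustration: a toy `potential` in place of the real network, a fixed-step Euler loop in place of the adaptive Heun solve, and `jax.checkpoint` applied per step as a rough stand-in for solver-level checkpointing. The vector field is the gradient of the potential, so the parameter gradient involves higher-order autodiff.

```python
import time

import jax
import jax.numpy as jnp

DT = 0.05      # fixed step size (assumption; the thread uses an adaptive solver)
N_STEPS = 20   # number of Euler steps (assumption)


def potential(params, y):
    # Hypothetical toy potential standing in for the real network.
    return jnp.sum(jnp.tanh(params @ y))


# Vector field = gradient of the potential wrt the state y.
vector_field = jax.grad(potential, argnums=1)


def euler_step(params, y):
    return y + DT * vector_field(params, y)


def loss(params, y0, remat):
    # With remat=True, each step's residuals are recomputed on the backward pass.
    step = jax.checkpoint(euler_step) if remat else euler_step

    def body(y, _):
        return step(params, y), None

    y1, _ = jax.lax.scan(body, y0, None, length=N_STEPS)
    return jnp.sum(y1 ** 2)


def timed(fn, *args):
    fn(*args).block_until_ready()  # warm-up call, includes compilation
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start


key_p, key_y = jax.random.split(jax.random.PRNGKey(0))
params = 0.1 * jax.random.normal(key_p, (64, 64))
y0 = jax.random.normal(key_y, (64,))

forward = jax.jit(lambda p: loss(p, y0, False))
grad_plain = jax.jit(jax.grad(lambda p: loss(p, y0, False)))
grad_remat = jax.jit(jax.grad(lambda p: loss(p, y0, True)))

_, t_fwd = timed(forward, params)
g_plain, t_plain = timed(grad_plain, params)
g_remat, t_remat = timed(grad_remat, params)
print(f"forward {t_fwd:.2e}s, grad {t_plain:.2e}s, grad+checkpoint {t_remat:.2e}s")
```

The two gradients should agree numerically, so any difference in timing between them isolates the cost of recomputation from the cost of the higher-order derivative itself. Timings on a toy problem of this size will not necessarily reflect the behaviour of the full convolutional model.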

deasmhumhna, Oct 26 '22 01:10