
Inspection of gradient calculation

Open moesphere opened this issue 2 years ago • 3 comments

Hello!

Is there a tutorial on how to inspect gradient/hessian calculation when using diffrax?

I am trying to calculate gradients and Hessians, but the returned values are always NaNs, and I do not know where to start debugging.

moesphere avatar Jul 21 '23 11:07 moesphere

So there's nothing special about Diffrax -- I'd recommend debugging NaNs in the same way as any JAX program.

First of all, one very common source of NaNs during autodiff is a missing double where trick. (Needing this trick is actually a general fact of autodifferentiation systems! Nothing JAX-specific here.)
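As a minimal sketch of the double where trick (the function names here are illustrative, not taken from Diffrax or from the original question): a single jnp.where guards the forward value but not the gradient, because the unselected branch is still differentiated and its NaN derivative gets multiplied by zero. Applying a second where to the input avoids this.

```python
import jax
import jax.numpy as jnp


def naive_safe_sqrt(x):
    # The forward value is fine (0.0 for x <= 0), but the sqrt branch is still
    # differentiated at x, so the gradient becomes 0 * NaN = NaN.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)


def double_where_sqrt(x):
    positive = x > 0
    # Inner where: replace invalid inputs with a harmless dummy value (1.0)
    # so that sqrt and its derivative are always finite.
    safe_x = jnp.where(positive, x, 1.0)
    # Outer where: select the result that is actually wanted.
    return jnp.where(positive, jnp.sqrt(safe_x), 0.0)


naive_grad = jax.grad(naive_safe_sqrt)(-1.0)
safe_grad = jax.grad(double_where_sqrt)(-1.0)
print(naive_grad, safe_grad)
```

The dummy value 1.0 is arbitrary; any input for which both the operation and its derivative are finite would do.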

Second, another common source is when doing a sqrt or log of a negative number. If your vector field includes either of these operations, then consider checking whether their input is negative. (Note that Diffrax reserves the right to query your vector field with any values for t and y, even those outside the region in which the ODE is solved -- some solvers will make queries outside those ranges. So your vector field must be robust to such queries.) You can easily check this with jax.debug.print.
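For illustration, here is a sketch of such a check. The vector field below is hypothetical (it is not the model from the question) and only follows the (t, y, args) argument convention that Diffrax uses for vector fields; it is called directly here rather than through a solver, so the example does not depend on Diffrax being installed.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    # Hypothetical ODE dy/dt = -k * sqrt(y); `args` is assumed to hold k.
    # jax.debug.print works under jit, so the values are printed at run time.
    jax.debug.print(
        "vector_field: t={t}, y={y}, negative input={neg}", t=t, y=y, neg=y < 0
    )
    return -args * jnp.sqrt(y)


# A query inside the physically meaningful region.
out_ok = jax.jit(vector_field)(0.0, jnp.array(4.0), 0.5)
# A query outside it, of the kind a solver might make.
out_bad = jax.jit(vector_field)(0.0, jnp.array(-1.0), 0.5)
```

If the printed output shows negative inputs reaching the sqrt, the vector field needs to be made robust to them, for example with the double where trick.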

If it's not any of those common errors, then your main tools to track this down are:

  • jax.debug.print to print intermediate quantities.
  • jax.debug.breakpoint to put a breakpoint in the middle of your computation. (Make sure to pass the -s flag if used in conjunction with pytest.)
  • Set the environment variable JAX_DEBUG_NANS=1 to interrupt the computation when a NaN is detected. This works best if the NaN occurs on the forward pass and outside of any loops. (If the NaN occurs inside a loop then adding JAX_DISABLE_JIT=1 can also help. And once you become more advanced, then JAX_TRACEBACK_FILTERING=off can also help, as it allows you to inspect JAX internals.)
  • If your NaNs only arise on the backward pass, but not the forward pass, then try equinox.internal.debug_backward_nan.
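As a small sketch of the NaN-detection option: the same behaviour as JAX_DEBUG_NANS=1 can be switched on from within Python via jax.config.update. The function f below is a made-up stand-in for a forward pass that produces a NaN.

```python
import jax
import jax.numpy as jnp

# Equivalent to setting the environment variable JAX_DEBUG_NANS=1.
jax.config.update("jax_debug_nans", True)


def f(x):
    # log of a negative number produces NaN.
    return jnp.log(x)


try:
    f(jnp.array(-1.0))
    caught = False
except FloatingPointError as e:
    # JAX raises as soon as an operation produces a NaN.
    caught = True
    print("NaN detected:", e)

# Switch the check off again, since it slows computations down.
jax.config.update("jax_debug_nans", False)
```

The traceback of the raised error points at the operation that produced the NaN, which gives a starting point for bisecting.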

Generally speaking, NaN issues aren't usually too tricky to track down. Just bisect through your code until you find the operation that's producing them.

Moreover I'd encourage you to be willing to place these checks inside your copy of Diffrax's code. You can get the install location via import diffrax; print(diffrax.__file__). (Placing breakpoints within the library is often a quicker way of debugging, as you have finer-grained control over your bisection search.)

patrick-kidger avatar Jul 21 '23 15:07 patrick-kidger

Thank you for the fast and informative answer! I will look into it.

I have used the same model equations with CVODES from SUNDIALS (via CasADi). The gradient calculations, i.e., the parameter sensitivities, succeeded and gave physically feasible results. Does this information favor any of the debugging strategies you mentioned above?

moesphere avatar Jul 21 '23 16:07 moesphere

That probably doesn't help, I'm afraid!

patrick-kidger avatar Jul 21 '23 16:07 patrick-kidger