Inspection of gradient calculation
Hello!
Is there a tutorial on how to inspect gradient/Hessian calculations when using Diffrax?
I am trying to calculate gradients and Hessians, but the returned values are always NaNs, and I do not know where to start "debugging".
So there's nothing special about Diffrax -- I'd recommend debugging NaNs in the same way as any JAX program.
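First of all, one very common source of NaNs during autodiff is a missing "double where" trick: wrapping an unsafe operation in a single jnp.where is not enough to keep NaNs out of the gradient. (Needing this trick is actually a general fact of autodifferentiation systems! Nothing JAX-specific here.)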
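A minimal sketch of the double where trick in plain JAX (naive_sqrt and safe_sqrt are made-up names for illustration). A single jnp.where still differentiates through both branches: the unselected branch receives a cotangent of zero, which is then multiplied by sqrt's infinite or NaN derivative, and 0 * inf = nan. The fix is a first where that swaps unsafe inputs for a harmless value before the operation, and a second where that selects the output.

```python
import jax
import jax.numpy as jnp

def naive_sqrt(x):
    # Single where: the sqrt branch is still differentiated at x <= 0,
    # giving 0 * inf (or 0 * nan) = nan in the gradient.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)

def safe_sqrt(x):
    # First where: replace unsafe inputs with a harmless value (1.0),
    # so sqrt and its derivative are finite on both branches.
    safe_x = jnp.where(x > 0, x, 1.0)
    # Second where: select the intended output.
    return jnp.where(x > 0, jnp.sqrt(safe_x), 0.0)

print(jax.grad(naive_sqrt)(0.0))  # nan
print(jax.grad(safe_sqrt)(0.0))   # 0.0
```

The forward values of the two functions are identical; only the gradient differs.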
Second, another common source is when doing a sqrt or log of a negative number. If your vector field includes either of these operations, then consider checking whether their input is negative. (Note that Diffrax reserves the right to query your vector field with any values for t and y, even those outside the region in which the ODE is solved -- some solvers will make queries outside those ranges. So your vector field must be robust to such queries.) You can easily check this with jax.debug.print.
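As an illustration, here is a hypothetical vector field using the (t, y, args) signature that Diffrax vector fields take, with a jax.debug.print reporting the smallest value passed to sqrt. The dynamics are invented for the example; in practice you would wrap the function in diffrax.ODETerm and solve as usual. Calling it directly under jax.jit is enough to show the diagnostic, since jax.debug.print runs at execution time rather than at trace time.

```python
import jax
import jax.numpy as jnp

def vector_field(t, y, args):
    k = args
    # Hypothetical quantity that gets passed to sqrt.
    sqrt_input = k - y
    # Prints on every evaluation, including under jit and inside solver loops.
    jax.debug.print("t={t}: min sqrt input = {m}", t=t, m=jnp.min(sqrt_input))
    return -jnp.sqrt(sqrt_input)

# A query with y outside the "expected" region: k - y is negative
# for the second entry, so sqrt returns nan there.
out = jax.jit(vector_field)(0.0, jnp.array([0.5, 2.0]), 1.0)
print(out)  # second entry is nan
```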
If it's not any of those common errors, then your main tools to track this down are:
- jax.debug.print to print intermediate quantities.
- jax.debug.breakpoint to put a breakpoint in the middle of your computation. (Make sure to pass the -s flag if used in conjunction with pytest.)
- Set the environment variable JAX_DEBUG_NANS=1 to interrupt the computation when a NaN is detected. This works best if the NaN occurs on the forward pass and outside of any loops. (If the NaN occurs inside a loop, then adding JAX_DISABLE_JIT=1 can also help. And once you become more advanced, JAX_TRACEBACK_FILTERING=off can also help, as it allows you to inspect JAX internals.)
- If your NaNs only arise on the backward pass, but not the forward pass, then try equinox.internal.debug_backward_nan.
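The JAX_DEBUG_NANS=1 environment variable has an in-Python equivalent, jax.config.update("jax_debug_nans", True). A small sketch, where loss and check_for_nans are hypothetical names: with the flag on, JAX raises a FloatingPointError (typically naming the primitive that produced the NaN, here the log of a negative number) instead of silently returning nan.

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(x):
    # log of a negative entry produces a NaN on the forward pass.
    return jnp.sum(jnp.log(x))

def check_for_nans(x):
    """Hypothetical helper: returns True if JAX's NaN checker fires."""
    # In-Python equivalent of running with JAX_DEBUG_NANS=1.
    jax.config.update("jax_debug_nans", True)
    try:
        loss(x)
        return False
    except FloatingPointError as err:
        print("NaN detected:", err)
        return True
    finally:
        # Switch the (slow) checker off again.
        jax.config.update("jax_debug_nans", False)

print(check_for_nans(jnp.array([1.0, -1.0])))
print(check_for_nans(jnp.array([1.0, 2.0])))
```

The checker slows things down, so it is best enabled only while hunting for the NaN.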
Generally speaking, NaN issues aren't too tricky to track down: just bisect through your code until you find the operation that's producing them.
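One way to bisect on the backward pass is an identity function with a custom VJP that reports whether NaNs are present in the values flowing forward and in the cotangents flowing backward. This is only a sketch of the idea, not the implementation of equinox.internal.debug_backward_nan; report_nan and the checkpoint names are hypothetical. Placing checkpoints before and after a suspect operation narrows down where a NaN first appears.

```python
import functools
import jax
import jax.numpy as jnp

def _check(tag, v):
    # Prints at run time, so it also works under jit and inside loops.
    jax.debug.print(tag + " contains NaN: {b}", b=jnp.any(jnp.isnan(v)))

@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def report_nan(x, name):
    # Hypothetical helper: an identity function that only reports.
    _check(f"[{name}] forward", x)
    return x

def _report_nan_fwd(x, name):
    _check(f"[{name}] forward", x)
    return x, None  # no residuals needed

def _report_nan_bwd(name, _residuals, g):
    # g is the cotangent flowing back into this point of the computation.
    _check(f"[{name}] backward", g)
    return (g,)

report_nan.defvjp(_report_nan_fwd, _report_nan_bwd)

def f(x):
    x = report_nan(x, "input")
    y = report_nan(jnp.sqrt(x), "sqrt")
    return jnp.where(x > 0, y, 0.0)

grad_at_zero = jax.grad(f)(0.0)
print(grad_at_zero)  # nan
```

Here the backward cotangent at the "sqrt" checkpoint is NaN-free while the one at "input" contains a NaN, which pins the problem on the sqrt between them (an instance of the missing double where trick).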
Moreover, I'd encourage you to be willing to place these checks inside your copy of Diffrax's code. You can get the install location via import diffrax; print(diffrax.__file__). (Placing breakpoints within the library is often a quicker way of debugging, as you have finer-grained control over your bisection search.)
Thank you for the fast and informative answer! I will look into it.
I have used the same model equations with CVODES from SUNDIALS (via CasADi). There, the gradient calculations, i.e. the parameter sensitivities, succeeded and gave physically feasible results. Does this information point towards any particular debugging strategy from the ones you mentioned above?
That probably doesn't help, I'm afraid!