diffrax
Taking more than one gradient fails with default RecursiveCheckpointAdjoint
I am a total beginner with JAX and diffrax, so I am not sure whether this is a bug or expected behavior: if I try to compute the second (or higher) derivative of a solution from diffeqsolve(), I get an error. Changing the adjoint to DirectAdjoint() seems to fix the problem.
Minimal working example (using the default ODE example from the diffrax introduction):
import jax.numpy as jnp
import numpy as np
import jax
from diffrax import diffeqsolve, ODETerm, Dopri5, DirectAdjoint
z = 2.3
t = 1.
def rhot(z):
    def f(t, y, args):
        return -z*y
    term = ODETerm(f)
    solver = Dopri5()
    y0 = jnp.array([2., 3.])
    solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0)
    #solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0, adjoint = DirectAdjoint()) #changing the adjoint fixes it
    return solution.ys[0][0]
drhozdz = jax.grad(rhot,argnums = 0)
d2rhozdz = jax.grad(drhozdz,argnums = 0)
print("expected state ", np.exp(-z*t)*2.)
print("found state ", rhot(z))
print("expected ", -2.*t*np.exp(-z*t))
print("found 1st derivative ", drhozdz(z))
print("expected 2nd ", 2.*t**2*np.exp(-z*t))
print("found 2nd derivative ", d2rhozdz(z)) #fails with default adjoint
The error returned is:
    print("found 2nd deriative ", d2rhozdz(z)) #fails with default adjoint
                                  ^^^^^^^^^^^
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
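For context, the ValueError text above is raised by JAX itself, and the same message can be reproduced without diffrax. The sketch below is only an illustration in plain JAX: it uses a hand-written fixed-step Euler loop as a stand-in for a solver (this is not how diffrax is implemented), and the names `H`, `N_STEPS`, `euler_while` and `euler_scan` are made up for the example. One difference from the report above: in plain JAX even the first reverse-mode gradient through `lax.while_loop` fails, whereas with diffrax the first gradient works and only the second one fails.

```python
import jax
import jax.numpy as jnp
from jax import lax

H = 0.1       # hypothetical fixed step size
N_STEPS = 10  # hypothetical step count, so the final time is 1.0


def euler_while(z):
    # Explicit Euler for dy/dt = -z*y, y(0) = 2, written with lax.while_loop.
    def cond(carry):
        i, _ = carry
        return i < N_STEPS

    def body(carry):
        i, y = carry
        return i + 1, y - H * z * y

    _, y = lax.while_loop(cond, body, (jnp.int32(0), jnp.float32(2.0)))
    return y


def euler_scan(z):
    # The same Euler iteration, written with lax.scan and a static length.
    def step(y, _):
        return y - H * z * y, None

    y, _ = lax.scan(step, jnp.float32(2.0), None, length=N_STEPS)
    return y


z = 2.3

# Reverse mode through lax.while_loop: JAX cannot transpose the loop.
try:
    jax.grad(euler_while)(z)
    reverse_through_while = "no error"
except ValueError as err:
    reverse_through_while = str(err)
print("grad through while_loop:", reverse_through_while)

# Forward mode through lax.while_loop works, including nested forward mode.
d2_while = jax.jacfwd(jax.jacfwd(euler_while))(z)

# Reverse mode through lax.scan works, including nested reverse mode.
d2_scan = jax.grad(jax.grad(euler_scan))(z)

print("2nd derivative (forward mode, while_loop):", d2_while)
print("2nd derivative (reverse mode, scan):      ", d2_scan)
```

Both second derivatives are those of the Euler approximation 2*(1 - H*z)**N_STEPS, so they only roughly match the exact value 2.*t**2*np.exp(-z*t) from the script above.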