
Taking more than one gradient fails with default RecursiveCheckpointAdjoint

Open nwlambert opened this issue 7 months ago • 2 comments

I am a total beginner with JAX and Diffrax, so I am not sure whether this is a bug or expected behaviour, but if I try to compute the second (or any higher) derivative of a solution from diffeqsolve(), I get an error. Changing the adjoint to DirectAdjoint() seems to fix the problem.

Minimal working example (using the default ODE example from the diffrax introduction):

import jax.numpy as jnp
import numpy as np
import jax
from diffrax import diffeqsolve, ODETerm, Dopri5, DirectAdjoint

z = 2.3
t = 1.
        
def rhot(z):
    def f(t, y, args):
        return -z*y

    term = ODETerm(f)
    solver = Dopri5()
    y0 = jnp.array([2., 3.])
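    # Default adjoint (RecursiveCheckpointAdjoint): the second derivative fails.
    solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0)
    # Workaround: uncomment to use DirectAdjoint, which makes the second derivative work.
    # solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0, adjoint=DirectAdjoint())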
    return solution.ys[0][0]

drhozdz = jax.grad(rhot,argnums = 0)
d2rhozdz = jax.grad(drhozdz,argnums = 0)


print("expected state ", np.exp(-z*t)*2.)
print("found state ", rhot(z))


print("expected ", -2.*t*np.exp(-z*t))
print("found 1st derivative ", drhozdz(z))

print("expected 2nd ", 2.*t**2*np.exp(-z*t))
print("found 2nd derivative ", d2rhozdz(z))  #fails with default adjoint

The error returned is:

print("found 2nd deriative ", d2rhozdz(z)) #fails with default adjoint
                              ^^^^^^^^^^^
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
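
For context, the same ValueError can be reproduced in plain JAX without diffrax. The sketch below is only an illustration of the error message, not a description of how diffrax's adjoints are implemented: it assumes a hand-written fixed-step Euler integrator for dy/dt = -z*y, and the names euler_while, euler_scan, N_STEPS and DT are made up for this example. Reverse-mode differentiation through lax.while_loop raises the error, while the lax.scan version (which has a static number of steps) can be differentiated twice.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 10  # hypothetical fixed step count
DT = 0.1      # hypothetical step size, so the loop integrates from t=0 to t=1


def euler_while(z):
    """Euler steps for dy/dt = -z*y, y(0) = 2, using lax.while_loop."""

    def cond(state):
        i, _ = state
        return i < N_STEPS

    def body(state):
        i, y = state
        return i + 1, y - DT * z * y

    _, y = lax.while_loop(cond, body, (jnp.int32(0), jnp.float32(2.0)))
    return y


def euler_scan(z):
    """The same Euler steps, using lax.scan with a static length."""

    def step(y, _):
        return y - DT * z * y, None

    y, _ = lax.scan(step, jnp.float32(2.0), None, length=N_STEPS)
    return y


z = 2.3

# Reverse-mode autodiff through lax.while_loop is not supported by JAX.
try:
    jax.grad(euler_while)(z)
    while_loop_grad_failed = False
except ValueError as err:
    while_loop_grad_failed = True
    print("while_loop gradient failed:", err)

# lax.scan supports reverse mode, so nested jax.grad calls work.
d2_scan = jax.grad(jax.grad(euler_scan))(z)
print("second derivative via scan:", d2_scan)
```

The scan version approximates the exact solution with a fixed-step Euler scheme, so its derivatives will only roughly match the analytic values printed in the example above.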

nwlambert, Nov 09 '23 07:11