
Correction for Adjoint Gradient Calculations

aarcher07 opened this issue 3 years ago

When computing the gradient with respect to the initial conditions, I find that the adjoint and forward gradient computations differ slightly. Following equation 14 of the CVODES manual, the discrepancy is in the adjoint equation and can be corrected by adding the constant -np.matmul(sens0, lambda_out - grads[0, :]) to the grad_out returned by solve_backward. Here sens0 is the matrix of initial sensitivities, lambda_out holds the adjoint variables at time 0, and grads[0, :] is the derivative of the likelihood with respect to the state variables at time 0.
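
In code, using the names from the example below, the correction amounts to:

grad_out_adj = -np.matmul(sens0, lambda_out - grads[0, :]) + grad_out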

I have to adjust lambda at time 0 by -grads[0, :] because line 691 of sunode/solver.py appears not to loop over the initial time point of grads.

Does SolveODEAdjointBackward in sunode/sunode/wrappers/as_aesara.py apply a similar correction when computing the gradient with respect to the initial conditions?

I attached some code below as an example. It is based on the example script in the README. It computes the gradient of the sum of Hares and Lynx at times 0, 5 and 10 with respect to alpha, beta and hares0, where hares0 is log10(Hares(0)).

import numpy as np
import sunode
import sunode.wrappers.as_aesara
import pymc as pm
import matplotlib.pyplot as plt
lib = sunode._cvodes.lib


def lotka_volterra(t, y, p):
    """Right hand side of Lotka-Volterra equation.

    All inputs are dataclasses of sympy variables, or in the case
    of non-scalar variables numpy arrays of sympy variables.
    """
    return {
        'hares': p.alpha * y.hares - p.beta * y.lynx * y.hares,
        'lynx': p.delta * y.hares * y.lynx - p.gamma * y.lynx,
    }

# initialize problem
problem = sunode.symode.SympyProblem(
    params={
        # We need to specify the shape of each parameter.
        # Any empty tuple corresponds to a scalar value.
        'alpha': (),
        'beta': (),
        'gamma': (),
        'delta': (),
        'hares0': ()
    },
    states={
        # The same for all state variables
        'hares': (),
        'lynx': (),
    },
    rhs_sympy=lotka_volterra,
    derivative_params=[
        # We need to specify with respect to which variables
        # gradients should be computed.
        ('alpha',),
        ('beta',),
        ('hares0',),
    ],
)

tvals = np.linspace(0, 10, 3)  # evaluation times t = 0, 5, 10

y0 = np.zeros((), dtype=problem.state_dtype)
y0['hares'] = 1e0
y0['lynx'] = 0.1
params_dict = {
    'alpha': 0.1,
    'beta': 0.2,
    'gamma': 0.3,
    'delta': 0.4,
    'hares0': 1e0
}


# Initial sensitivities d y(0) / d (alpha, beta, hares0). Only hares(0) depends on
# hares0: since hares0 = log10(hares(0)), d hares(0)/d hares0 = ln(10) * hares(0) = ln(10).
sens0 = np.zeros((3, 2))
sens0[2, 0] = np.log(10) * 1e0

solver = sunode.solver.Solver(problem, solver='BDF', sens_mode='simultaneous')
yout, sens_out = solver.make_output_buffers(tvals)


# gradient via forward sensitivity
solver.set_params_dict(params_dict)
solver.solve(t0=0, tvals=tvals, y0=y0, y_out=yout, sens0=sens0, sens_out=sens_out)

# Sum sensitivities over the output times and both states: the gradient of
# sum(hares + lynx) at t = 0, 5, 10 with respect to alpha, beta and hares0.
grad_out_fwd = [sens_out[:, j, :].sum() for j in range(3)]
print(grad_out_fwd)

# gradient via adjoint sensitivity
solver = sunode.solver.AdjointSolver(problem, solver='BDF')
solver.set_params_dict(params_dict)
tvals_expanded = np.linspace(0, 10, 21)  # denser grid that still contains t = 0, 5, 10
yout, grad_out, lambda_out = solver.make_output_buffers(tvals_expanded)
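# Allow CVODES up to 10000 internal steps per call when reaching the next output time.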
lib.CVodeSetMaxNumSteps(solver._ode, 10000)
solver.solve_forward(t0=0, tvals=tvals, y0=y0, y_out=yout)
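# d(objective)/dy: the objective is the sum of hares and lynx at t = 0, 5, 10,
# so the rows at those times are ones and all other rows stay zero.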
grads = np.zeros_like(yout)
grads[::10,:] = 1
solver.solve_backward(t0=tvals_expanded[-1], tend=tvals_expanded[0], tvals=tvals_expanded[1:-1],
                      grads=grads, grad_out=grad_out, lamda_out=lambda_out)
grad_out_adj = -np.matmul(sens0, lambda_out - grads[0, :]) + grad_out
print(grad_out_adj)

aarcher07 · Jul 12 '22 22:07

Thanks for opening such a detailed issue! I edited the formatting in your comment and added a link to that line.


I'm not familiar with the implementation, but is this a bug due to t_intervals and grads having different lengths, so that the for iterator never reaches the None element in reversed(grads)? Inside that loop there is an if grad is not None: check which appears to have been written for this t = 0 element.
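
As a purely hypothetical illustration of what I mean (not sunode's actual code): zip stops at the shorter sequence, so the trailing element of reversed(grads) can be skipped silently.

t_intervals = ["dt1", "dt2", "dt3"]   # stand-ins for the backward integration intervals
grads = [None, "g1", "g2", "g3"]      # one entry more; None stands for the t = 0 element
for dt, grad in zip(reversed(t_intervals), reversed(grads)):
    if grad is not None:              # the None (t = 0) element is never reached
        print(dt, grad)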

michaelosthege · Jul 15 '22 16:07

Great! Thank you for editing my post.

There is also another issue. When the adjoint equations are evaluated at too few time points, the resulting gradients are inaccurate compared to those from forward sensitivities.

Following the example above, if I evaluate the adjoint equation at times 0, 5, 10 with grads = np.ones_like(yout), then I get

  • grad_out_fwd = [29.633367875233063, -8.63361922455043, 10.2485995824757]
  • grad_out_adj = [ 8.08665182 -1.22041063 8.53419337].
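
Roughly, that coarse-grid run looks like this (a sketch that reuses problem, y0, params_dict, tvals and sens0 from the example above and mirrors the call pattern of my original script):

solver = sunode.solver.AdjointSolver(problem, solver='BDF')
solver.set_params_dict(params_dict)
yout, grad_out, lambda_out = solver.make_output_buffers(tvals)   # tvals = [0, 5, 10]
solver.solve_forward(t0=0, tvals=tvals, y0=y0, y_out=yout)
grads = np.ones_like(yout)   # d loss/d y = 1 at every evaluated time point
solver.solve_backward(t0=tvals[-1], tend=tvals[0], tvals=tvals[1:-1],
                      grads=grads, grad_out=grad_out, lamda_out=lambda_out)
grad_out_adj = -np.matmul(sens0, lambda_out - grads[0, :]) + grad_out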

However, as in my original post, if I evaluate the adjoint equations at np.linspace(0, 10, 21), which includes times 0, 5, 10, and zero-pad grads at the times not equal to 0, 5, 10, then I get

  • grad_out_fwd = [29.633367875233063, -8.63361922455043, 10.2485995824757]
  • grad_out_adj = [27.71999675 -7.42137746 10.507334 ].

Thank you for looking into these issues!

aarcher07 · Jul 15 '22 18:07

@aarcher07 Thank you for reporting this, and sorry for the very late reply...

I think the problem you are seeing is due to a small mistake in the arguments to solve_backward. If I replace it with this, I get the same results as the forward solver:


# Instead of this
#solver.solve_backward(t0=tvals_expanded[-1], tend=tvals_expanded[0], tvals=tvals_expanded[1:-1],
#                      grads=grads, grad_out=grad_out, lamda_out=lambda_out)

# It should be this
solver.solve_backward(
    t0=tvals_expanded[-1],
    tend=tvals_expanded[0],
    tvals=tvals_expanded,
    grads=grads,
    grad_out=grad_out,
    lamda_out=lambda_out
)

grad_out_adj = -sens0 @ lambda_out + grad_out
print(grad_out_adj)

# Output

# from forward
# [29.633367875233063, -8.63361922455043, 10.2485995824757]

# from adjoint
# [29.63336772 -8.63361915 10.24859955]

The problem is that by passing tvals=tvals_expanded[1:-1] we don't actually use the first two entries of grads, and the time points for those gradients no longer match the correct tvals.
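
To make the mismatch concrete, a quick shape check with the arrays from the example (illustrative only):

tvals_expanded = np.linspace(0, 10, 21)
grads = np.zeros((21, 2))
grads[::10, :] = 1                             # rows intended for t = 0, 5, 10

print(len(tvals_expanded[1:-1]), len(grads))   # 19 vs. 21
# Two rows of grads are never consumed, and the remaining rows end up paired
# with time points they were not meant for.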

aseyboldt · Nov 26 '22 04:11

I'm closing this because I think it was a problem in the example code, but feel free to reopen or comment if you don't agree or have questions.

aseyboldt · Nov 26 '22 19:11