pyPESTO
pyPESTO copied to clipboard
Using PyPesto with jax
Hello, i want to use your package in order to do parameter estimation of ODEs and later on compute confidence intervals with profile likelihood. However to me the following code is not working on a toy example. Can you maybe tell me what i am doing wrong? I want to use the scatter search since it has been shown better convergence properties tha purely local or global methods.
import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing
# Lotka-Volterra model
def vector_field(t, y, args):
prey, predator = y
α, β, γ, δ = args
d_prey = α * prey - β * prey * predator
d_predator = -γ * predator + δ * prey * predator
return jnp.stack([d_prey, d_predator])
def solve(parameters, y0, ts):
term = dfx.ODETerm(vector_field)
solver = dfx.Tsit5()
saveat = dfx.SaveAt(ts=ts)
sol = dfx.diffeqsolve(
term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
adjoint=dfx.RecursiveCheckpointAdjoint(),
)
return sol.ys
# Generate synthetic data
def get_data():
y0 = jnp.array([9.0, 9.0])
true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
ts = jnp.linspace(0, 30, 20)
values = solve(true_parameters, y0, ts)
return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)
y0, ts, noisy_values = get_data()
# Define objective function
@jax.jit
def objective(parameters):
pred_values = solve(parameters, y0, ts)
return jnp.sum((noisy_values - pred_values)**2)
#objective_with_grad = jax.value_and_grad(objective)
objective = pypesto.Objective(
fun=objective,
grad=jax.grad(objective)
)
problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
default_ess_options = pypesto.optimize.get_default_ess_options(8, 4, local_optimizer=ScipyOptimizer(method='trust-constr'))
optimizer = pypesto.optimize.SacessOptimizer(ess_init_args = default_ess_options, max_walltime_s=600)
result_custom_problem = optimizer.minimize(problem=problem1)