pyPESTO icon indicating copy to clipboard operation
pyPESTO copied to clipboard

Using PyPesto with jax

Open MaAl13 opened this issue 7 months ago • 11 comments

Hello, i want to use your package in order to do parameter estimation of ODEs and later on compute confidence intervals with profile likelihood. However to me the following code is not working on a toy example. Can you maybe tell me what i am doing wrong? I want to use the scatter search since it has been shown better convergence properties tha purely local or global methods.

import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing

# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])

def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
        adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys

# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)

y0, ts, noisy_values = get_data()

# Define objective function
@jax.jit
def objective(parameters):
    pred_values = solve(parameters, y0, ts)
    return jnp.sum((noisy_values - pred_values)**2)

#objective_with_grad = jax.value_and_grad(objective)

objective = pypesto.Objective(
    fun=objective,
    grad=jax.grad(objective)
)

problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
default_ess_options = pypesto.optimize.get_default_ess_options(8, 4, local_optimizer=ScipyOptimizer(method='trust-constr'))
optimizer = pypesto.optimize.SacessOptimizer(ess_init_args = default_ess_options, max_walltime_s=600)
result_custom_problem = optimizer.minimize(problem=problem1)

MaAl13 avatar Jul 11 '24 14:07 MaAl13