diffrax
diffrax copied to clipboard
Problem when using Diffrax for numpyro
Hello, I want to use Diffrax for Bayesian inference of parameters in numpyro. However, as soon as I change the step size controller from ConstantStepSize to PIDController I get an error. Changing max_steps to really big numbers and also using an implicit solver doesn't help. Can you maybe tell me what the problem is? The code is the following:
import sys
# from pathlib import Path
from jax.experimental.ode import odeint
# import arviz as az
import dill
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from diffrax import PIDController, Dopri5, ODETerm, SaveAt, diffeqsolve, Kvaerno5
from jax import random
from numpyro.infer import MCMC, NUTS # , Predictive
from sklearn.preprocessing import LabelEncoder
# from fem_cycle_model.fetch_params import fetch_params
# from fem_cycle_model.main import run_model
# Global runtime configuration; these calls run at import time, before any
# model code below is executed.
# Ask numpyro to enable 64-bit (x64) mode.
numpyro.enable_x64()
# Turn off pandas' chained-assignment warning.
pd.options.mode.chained_assignment = None # default='warn'
# Abort early if the installed numpyro is not a 0.11.0 release.
assert numpyro.__version__.startswith("0.11.0")
# Select the number of cores that numpyro will use
# NOTE(review): the MCMC run further down requests num_chains=4; with a
# single host device numpyro warns and draws the chains sequentially.
numpyro.set_host_device_count(1)
def test_ode(t, z, theta):
"""
Lotka–Volterra equations. Real positive parameters `alpha`, `beta`, `gamma`, `delta`
describes the interaction of two species.
"""
u = z[0]
v = z[1]
alpha, beta, gamma, delta = (
theta[0],
theta[1],
theta[2],
theta[3],
)
#print(theta)
du_dt = (alpha - beta * v) * u
dv_dt = (-gamma + delta * u) * v
return jnp.stack([du_dt, dv_dt])
# Simulate on 50 evenly spaced time points in [0, 10] (spacing 10/49, roughly 0.204)
def sim(params):
    """
    Integrate `test_ode` from t=0 to t=10 with the given parameters.

    The system starts at `[1, 0.2]` and is solved with `Dopri5` under an
    adaptive `PIDController` (rtol=1e-3, atol=1e-6), with at most 20000
    solver steps. The solution is saved on 50 evenly spaced times in
    [0, 10]; the same grid is passed to the controller as `step_ts`.

    `params` is forwarded unchanged to the vector field as its `args`.

    Returns a pair `(first_component, second_component)`: the two columns
    of the saved solution.
    """
    # One grid, used both for saving and as the controller's `step_ts`.
    save_times = np.linspace(0, 10, 50)
    controller = PIDController(rtol=1e-3, atol=1e-6, step_ts=save_times)

    solution = diffeqsolve(
        ODETerm(test_ode),
        Dopri5(),
        t0=0,
        t1=10,
        dt0=0.1,
        y0=jnp.array([1, 0.2]),
        args=params,
        saveat=SaveAt(ts=save_times),
        stepsize_controller=controller,
        max_steps=20000,
    )
    # Alternative solver call kept for reference:
    # sol = odeint(test_ode, jnp.array([1,0.2]), jnp.linspace(0, 10, 50), jnp.array(params), rtol=1e-6, atol=1e-5, mxstep=1000)
    trajectory = solution.ys
    return trajectory[:, 0], trajectory[:, 1]
def run_model_data(num_patients, params):
    """
    Simulate `num_patients` trajectories and concatenate them end to end.

    Patient `i` (0-based) is simulated with `params` scaled by `i + 1`.

    Returns a pair of 1-D arrays: all first-component trajectories joined
    together, and all second-component trajectories joined together. With
    `num_patients == 0` both are empty arrays.
    """
    base_params = jnp.array(params)
    # Seed each list with an empty array so the result (including its dtype
    # promotion and the zero-patient case) matches appending onto an empty
    # array one trajectory at a time.
    first_chunks = [jnp.array([])]
    second_chunks = [jnp.array([])]
    for patient in range(num_patients):
        first, second = sim(base_params * (patient + 1))
        first_chunks.append(first)
        second_chunks.append(second)
    return jnp.concatenate(first_chunks), jnp.concatenate(second_chunks)
# Generate the synthetic "observed" data: one patient simulated with
# parameters [2, 3, 4, 5] (scaled by 1 inside run_model_data).
patient_y1, patient_y2 = run_model_data(1, [2, 3, 4, 5])
def run_model_all(params):
    """Run a single simulation with `params` and return its two trajectories."""
    return sim(params)
def model(y1=None, y2=None):
    """
    NumPyro model for the four ODE parameters and the observation noise.

    Samples the noise scale `σ` and the parameters `foll_alpha`, `foll_beta`,
    `foll_gamma`, `foll_delta` from left-truncated normal priors, simulates
    the ODE with them, and scores the two simulated trajectories against the
    observations under independent normal noise.

    Args:
        y1: Observations for the first trajectory, or None to leave the
            "obs_dom" site unobserved.
        y2: Observations for the second trajectory, or None to leave the
            "obs_non_dom" site unobserved.
    """
    # Disabled hierarchical variant of the priors, kept for reference:
    # μ_foll_alpha = numpyro.sample("μ_foll_alpha", dist.Uniform(1.0, 7.0))
    # σ_foll_alpha = numpyro.sample("σ_foll_alpha", dist.LeftTruncatedDistribution(dist.Normal(3.0, 2.0)))
    # μ_foll_beta = numpyro.sample("μ_foll_beta", dist.Uniform(2.0, 10.0))
    # σ_foll_beta = numpyro.sample("σ_foll_beta", dist.LeftTruncatedDistribution(dist.Normal(4.0, 2.0)))
    # μ_foll_gamma = numpyro.sample("μ_foll_gamma", dist.Uniform(3.0, 13.0))
    # σ_foll_gamma = numpyro.sample("σ_foll_gamma", dist.LeftTruncatedDistribution(dist.Normal(5.0, 2.0)))
    # μ_foll_delta = numpyro.sample("μ_foll_delta", dist.Uniform(4.0, 16.0))
    # σ_foll_delta = numpyro.sample("σ_foll_delta", dist.LeftTruncatedDistribution(dist.Normal(7.0, 2.0)))
    # σ = numpyro.sample("σ", dist.LeftTruncatedDistribution(dist.Normal(0.3, .5)))
    # #with numpyro.plate("plate_i", N_PATIENTS):
    # foll_alpha = numpyro.sample("foll_alpha", dist.Normal(μ_foll_alpha, σ_foll_alpha))
    # foll_beta = numpyro.sample("foll_beta", dist.Normal(μ_foll_beta, σ_foll_beta))
    # foll_gamma = numpyro.sample("foll_gamma", dist.Normal(μ_foll_gamma, σ_foll_gamma))
    # foll_delta = numpyro.sample("foll_delta", dist.Normal(μ_foll_delta, σ_foll_delta))

    σ = numpyro.sample("σ", dist.LeftTruncatedDistribution(dist.Normal(0.3, .5)))
    foll_alpha = numpyro.sample("foll_alpha", dist.LeftTruncatedDistribution(dist.Normal(2, 1)))
    foll_beta = numpyro.sample("foll_beta", dist.LeftTruncatedDistribution(dist.Normal(3, 1)))
    foll_gamma = numpyro.sample("foll_gamma", dist.LeftTruncatedDistribution(dist.Normal(4, 1)))
    foll_delta = numpyro.sample("foll_delta", dist.LeftTruncatedDistribution(dist.Normal(5, 1)))

    y1_est, y2_est = run_model_all(
        [
            foll_alpha,
            foll_beta,
            foll_gamma,
            foll_delta,
        ]
    )

    # Size the plate from the simulated trajectory instead of len(y1): the
    # two agree whenever observations are supplied, and this also works when
    # the model is called with its default y1=None (len(None) would raise).
    with numpyro.plate("likelihood", y1_est.shape[0]):
        numpyro.sample("obs_dom", dist.Normal(y1_est, σ), obs=y1)
        numpyro.sample("obs_non_dom", dist.Normal(y2_est, σ), obs=y2)
# Observed data handed to the model as keyword arguments.
data_dict = dict(
    y1=patient_y1,
    y2=patient_y2,
)
# Specify the number of chains in the Markov Chain Monte Carlo. Typically set to the number of cores in the computer
mcmc_kwargs = dict(num_samples=2000, num_warmup=2000, num_chains=4)
# Select a random key and split it into different parts. This guarantees that we get the same result each time and
# will lead to reproducible results. For more see:
# https://ericmjl.github.io/dl-workshop/02-jax-idioms/03-deterministic-randomness.html
rng_key = random.PRNGKey(12)
seed1, seed2, seed3, seed4, seed5 = random.split(rng_key, 5)
inference_mcmc = MCMC(NUTS(model, init_strategy=numpyro.infer.init_to_sample(), dense_mass=True), **mcmc_kwargs)
inference_mcmc.run(seed1, **data_dict)
# print_summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line to the output.
inference_mcmc.print_summary()
The error is
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py:150: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using numpyro.set_host_device_count(4)
at the beginning of your program. You can double-check how many devices are available in your system using jax.local_device_count()
.
inference_mcmc = MCMC(NUTS(model, init_strategy=numpyro.infer.init_to_sample(), dense_mass=True), **mcmc_kwargs)
warmup: 0%| | 1/4000 [00:19<21:38:58, 19.49s/it, 1 steps of size 2.34e+00. acc. prob=0.00]
Traceback (most recent call last):
File "/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py", line 151, in max_steps
.
At:
/home/malmansto/anaconda3/lib/python3.9/site-packages/equinox/internal/errors.py(17): raises
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(142): _flat_callback
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(42): pure_callback_impl
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(105): _callback
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/mlir.py(1798): _wrapped_callback
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py(2113): call
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/api.py(565): cache_miss
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback
/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py(358): fori_collect
/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(404): _single_chain_mcmc
/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(160): _laxmap
/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(598): run
/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py(151):
The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py", line 151, in max_steps
.
At:
/home/malmansto/anaconda3/lib/python3.9/site-packages/equinox/internal/errors.py(17): raises
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(142): _flat_callback
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(42): pure_callback_impl
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(105): _callback
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/mlir.py(1798): _wrapped_callback
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py(2113): call
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/api.py(565): cache_miss
/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback
/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py(358): fori_collect
/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(404): _single_chain_mcmc
/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(160): _laxmap
/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(598): run
/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py(151):
If I had to guess what's going on here: the parameters are being suggested with the wrong sign, so that the Lotka-Volterra equations blow up in finite time.
I've not tried running your example as it is a little large. See if you can try reducing it to a MWE, in particular without numpyro. You should be able to simply check what parameters numpyro is suggesting, and then evaluate Diffrax on those parameters directly.
General debugging tips for this kind of thing by the way:
- Put `jax.debug.{print, breakpoint}` inside your vector field, to see what times and values it is being evaluated at. Do the times converge towards a single time (i.e. the point at which an equation blows up in finite time)? Do the values explode towards very large numbers, very small numbers, or `inf`, or `-inf`, or `nan`?
- Pass `diffeqsolve(..., throw=False, saveat=SaveAt(steps=True))` and see what times it is being evaluated at. You will get the times and values at which the solver placed steps. These are stored in arrays of length `max_steps`: the first part of this array will be the times you evaluate at; after that it will be padded with `inf`.