Request for Euler-Maruyama features in numpyro
Dear Numpyro developers,
Please develop Euler-Maruyama features in numpyro similar to the features found in PyMC.
Thanks a lot.
I think this could be a nice example to have. The simulation scheme can be found on Wikipedia. If a custom distribution is implemented, GaussianRandomWalk would be a good reference - the implementation should be similar except for the sample method, which might require jax.lax.scan for the simulation.
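For reference, here is a minimal, hedged sketch of what such a scan-based simulation could look like in plain jax. The helper name euler_maruyama_path and the drift/diffusion callables are illustrative assumptions, not numpyro API:

import jax
import jax.numpy as jnp
from jax import random

def euler_maruyama_path(key, y0, drift, diffusion, dt, num_steps):
    # pre-draw all Brownian increments, then step through them with scan
    dW = random.normal(key, (num_steps,)) * jnp.sqrt(dt)

    def step(y, dw):
        y_next = y + drift(y) * dt + diffusion(y) * dw
        return y_next, y_next

    _, path = jax.lax.scan(step, y0, dW)
    return path

# e.g. an Ornstein-Uhlenbeck process dy = -theta * y dt + sigma dW
path = euler_maruyama_path(
    random.PRNGKey(0), jnp.float32(1.0), lambda y: -0.5 * y, lambda y: 0.1, 0.01, 1000
)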
I'll try to write an example for this; it will serve as an opportunity for me to learn more about SDEs. In the meantime, @GMCobraz, this is how you can replicate the second model in this example using numpyro:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.contrib.control_flow import scan
from numpyro.infer import Predictive, NUTS, MCMC

numpyro.set_host_device_count(4)
RNG = random.PRNGKey(0)


def model2(num_timesteps, dt, z=None):
    # -- SDE params
    tau = numpyro.sample("tau", dist.Gamma(3, 2))
    a = numpyro.sample("a", dist.Gamma(4, 4))
    sigma = numpyro.sample("sigma", dist.Gamma(3, 6))
    b = jnp.sqrt(dt) * sigma
    # -- model params
    m = numpyro.sample("m", dist.Beta(3, 3))
    noise = numpyro.sample("noise", dist.LogNormal())
    # -- initial values
    with numpyro.plate("plate_v0", 2):
        v0 = numpyro.sample("v0", dist.Normal())

    def transition(carry, _):
        # -- unpack state
        v_t, = carry
        x_t, y_t = v_t
        # -- model for observations
        z_t = numpyro.sample("z", dist.Normal(mu_t := m * x_t + (1 - m) * y_t, noise))
        # -- differential equation system
        dx = tau * (x_t - x_t ** 3.0 / 3.0 + y_t)
        dy = (a - x_t) / tau
        dv = jnp.c_[dx, dy].squeeze()
        # -- Euler-Maruyama update
        v_next = numpyro.sample("v", dist.Normal(v_t + dv * dt, b))
        # -- carry over state
        carry = (v_next,)
        return carry, None

    timesteps = jnp.arange(num_timesteps)
    init = (v0,)
    with numpyro.handlers.condition(data={"z": z}):
        scan(transition, init, timesteps)
And this is how you can use it:
num_timesteps = 200
dt = 0.1

# --- simulate with some fixed parameters
fixed_model2 = numpyro.handlers.condition(
    model2,
    {
        "tau": 3.0,
        "a": 1.05,
        "m": 0.2,
        "sigma": 1e-1,
        "noise": 1e-1,
        "v0": jnp.array([0.0, 0.1]),
    },
)
prior = Predictive(fixed_model2, num_samples=100)
prior_samples = prior(RNG, num_timesteps, dt)

# --- run mcmc
true_idx = 0
true_z = prior_samples["z"][true_idx]
mcmc = MCMC(
    NUTS(model2),
    num_warmup=2000,
    num_samples=2000,
    num_chains=4,
)
mcmc.run(RNG, num_timesteps, dt, true_z)
# mcmc.print_summary()  # for diagnostics

# --- posterior checks
posterior = Predictive(model2, mcmc.get_samples())
posterior_samples = posterior(RNG, num_timesteps, dt)

# --- extending into the future
num_future = 50
estimates = mcmc.get_samples().copy()
estimates["v0"] = estimates.pop("v")[:, -1, :]  # replace `v0` with the last value of `v` and drop `v`
predictive = Predictive(model2, posterior_samples=estimates)
predictive_samples = predictive(RNG, num_future, dt)
Note that I have changed some of the flat priors in the example, simply because I don't like flat priors.
+1 for Euler Maruyama in numpyro
In case you haven't seen this, the pymc implementation of the log_prob is just five lines and should be the same in numpyro/jax IMO (see below). That said, pymc does not have the random sampling method implemented, and I am not sure how to approach it. I guess one would need to do something similar to the logp calculation, but backwards?
class EulerMaruyama(distribution.Continuous):
    r"""
    Stochastic differential equation discretized with the Euler-Maruyama method.

    Parameters
    ----------
    dt: float
        time step of discretization
    sde_fn: callable
        function returning the drift and diffusion coefficients of SDE
    sde_pars: tuple
        parameters of the SDE, passed as ``*args`` to ``sde_fn``
    """

    def __init__(self, dt, sde_fn, sde_pars, *args, **kwds):
        super().__init__(*args, **kwds)
        self.dt = dt = at.as_tensor_variable(dt)
        self.sde_fn = sde_fn
        self.sde_pars = sde_pars

    def logp(self, x):
        """
        Calculate log-probability of EulerMaruyama distribution at specified value.

        Parameters
        ----------
        x: numeric
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        xt = x[:-1]
        f, g = self.sde_fn(x[:-1], *self.sde_pars)
        mu = xt + self.dt * f
        sd = at.sqrt(self.dt) * g
        return at.sum(Normal.dist(mu=mu, sigma=sd).logp(x[1:]))
If you want to implement it as a distribution, here is a sketch:

def sample(self, sample_shape):
    noises = random.normal(key, (num_steps,))
    # then use scan for the simulation to collect ys, where
    # y_next = y_curr + drift(y_curr, t_curr) * dt + diffusion(y_curr, t_curr) * sqrt(dt) * noise_curr
For a non-trivial sample_shape, we can vmap over scan and reshape the output. Using a distribution will make the inference faster because there is no scan in log_prob (see the GaussianRandomWalk distribution as an example).
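To make the vmap-over-scan idea concrete, here is a hedged sketch of the flatten/vmap/reshape pattern; sample_one_path is a hypothetical stand-in for the scan-based sampler sketched above, with drift and diffusion hard-coded for the demo:

import jax
import jax.numpy as jnp
import numpy as np
from jax import random

def sample_one_path(key, dt=0.01, num_steps=100):
    # stand-in single-path Euler-Maruyama sampler (an OU process for the demo)
    noises = random.normal(key, (num_steps,)) * jnp.sqrt(dt)

    def step(y, noise):
        y_next = y + (-0.5 * y) * dt + 0.1 * noise
        return y_next, y_next

    _, ys = jax.lax.scan(step, jnp.float32(0.0), noises)
    return ys

sample_shape = (4, 3)
flat_keys = random.split(random.PRNGKey(0), int(np.prod(sample_shape)))
paths = jax.vmap(sample_one_path)(flat_keys)   # shape (12, 100)
paths = paths.reshape(sample_shape + (-1,))    # shape (4, 3, 100)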
Thanks @fehiepsi, this is interesting. Though, admittedly, I am a little confused. IMO, the sample method you drafted is very similar to the log_prob calculation in pymc's EulerMaruyama.logp method that I quoted above.
Pymc has this developer guide: https://docs.pymc.io/en/v3/developer_guide.html, which says that all their distribution classes have the logp method, which returns the log probability and is used by all inference methods, and the .random method, which is used to simulate data for posterior predictive checks.
Do the .log_prob and .sample methods have the same intent in numpyro as .logp and .random in pymc? Or is it a different design?
Also, I'm not sure we need scan at all, as the calculation in the pymc example appears vectorized. But it is hard for me to speculate without understanding how the distribution classes are actually intended to work (or how scan works, for that matter) lol
Yes, they serve the same purpose. You can take a look at the GaussianRandomWalk distribution, where the sample method uses sequential ops (more precisely, the cumsum op) and log_prob is vectorized. Simulation in time is sequential in nature; the reason we can vectorize log_prob is that we already know the values at all steps. You can think of scan as just a for loop, where the next value depends on the current value. (I used similar notation, like drift and diffusion, and the same formula just to make things easier to follow - because the request is to have similar features to pymc3. The main reference is Wikipedia.)
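As an illustration of that vectorization, a numpyro/jax translation of the pymc logp quoted above might look roughly like this (a hedged sketch assuming the same dt, sde_fn, and sde_pars attributes; not merged numpyro code):

import jax.numpy as jnp
import numpyro.distributions as dist

def log_prob(self, value):
    # no scan needed: the whole path is known, so all steps are scored at once
    x_curr, x_next = value[..., :-1], value[..., 1:]
    f, g = self.sde_fn(x_curr, *self.sde_pars)
    mu = x_curr + self.dt * f
    scale = jnp.sqrt(self.dt) * g
    return dist.Normal(mu, scale).log_prob(x_next).sum(-1)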
Maybe we can have an API (similar to jax's odeint) like:
class EulerMaruyama(...):
    arg_constraints = {"t": ordered_vector}

    def __init__(self, sde_fn, init_dist, t, validate_args=None):
        # we use t rather than (dt, num_steps) to make the api simpler and more general;
        # `init_dist` will define the prior for the initial value;
        # event dimensions will be decided by `init_dist`
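To make the proposed signature concrete, a hypothetical usage sketch (the constructor call is commented out because EulerMaruyama is only a proposal at this point; the convention that sde_fn returns (drift, diffusion) follows the pymc class above):

import jax.numpy as jnp
import numpyro.distributions as dist

def sde_fn(y):
    # returns the drift and diffusion coefficients at state y
    return -0.5 * y, 0.1

t = jnp.linspace(0.0, 1.0, 100)  # an ordered time grid instead of (dt, num_steps)
init_dist = dist.Normal(0.0, 1.0)  # prior for the initial value
# em = EulerMaruyama(sde_fn, init_dist, t)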
I like it!
I gave it a few naive attempts, trying to subclass the GRW or the Distribution base class, but it appears to be slightly more involved than that lol
I want to try to contribute. I have a question about this line:
# event dimensions will be decided by `init_dist`
How is that possible? init_dist decides the batch dimensions, right?
Assuming that init_dist has shape batch_shape + event_shape, I guess the Euler-Maruyama distribution would have shape batch_shape + new_event_shape, where new_event_shape = (t,) + event_shape. Maybe we can assume init_dist.batch_shape == () for simplicity. What do you think?
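As a concrete illustration of that shape arithmetic (hypothetical, just spelling out the convention):

import jax.numpy as jnp
import numpyro.distributions as dist

init_dist = dist.Normal(jnp.zeros(2)).to_event(1)  # batch_shape == (), event_shape == (2,)
num_steps = 100
# the resulting Euler-Maruyama distribution over the whole path would have
# event_shape == (num_steps,) + init_dist.event_shape == (100, 2)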
@fehiepsi Thanks, but I'm still a bit confused. I made a draft PR in which I assume init_dist has shape batch_shape and event_shape is defined by t. Could you comment on the PR?
Resolved in #1504 by @yayami3 💯 !