
Request for Euler Maruyama features in numpyro

Open GMCobraz opened this issue 3 years ago • 8 comments

Dear Numpyro developers,

Please develop Euler Maruyama features in numpyro similar to features found in PyMC.

Thanks a lot.

GMCobraz avatar Oct 30 '21 05:10 GMCobraz

I think this could be a nice example to have. The simulation scheme can be found on Wikipedia. If a custom distribution is implemented, GaussianRandomWalk would be a good reference - the implementation should be similar except for the sample method, which might require jax.lax.scan for simulation.
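For reference, a minimal sketch of the Euler-Maruyama scheme from Wikipedia, X_{n+1} = X_n + drift(X_n) * dt + diffusion(X_n) * sqrt(dt) * eps_n, written with jax.lax.scan (the function and argument names here are illustrative, not numpyro API):

import jax.numpy as jnp
from jax import lax, random

def euler_maruyama(key, x0, drift, diffusion, dt, num_steps):
    # Simulate one path of dX = drift(X) dt + diffusion(X) dW.
    noises = random.normal(key, (num_steps,))

    def step(x, eps):
        x_next = x + drift(x) * dt + diffusion(x) * jnp.sqrt(dt) * eps
        return x_next, x_next

    _, xs = lax.scan(step, x0, noises)
    return xs

For example, euler_maruyama(random.PRNGKey(0), 0.0, lambda x: -x, lambda x: 0.5, 0.01, 1000) would simulate an Ornstein-Uhlenbeck path.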

fehiepsi avatar Oct 31 '21 00:10 fehiepsi

I'll try to write an example for this; it also serves as an opportunity for me to learn more about SDEs. In the meantime, @GMCobraz, this is how you can replicate the second model in this example using numpyro:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.contrib.control_flow import scan
from numpyro.infer import Predictive, NUTS, MCMC

numpyro.set_host_device_count(4)
RNG = random.PRNGKey(0)

def model2(num_timesteps, dt, z=None):
    # -- SDE params
    tau = numpyro.sample("tau", dist.Gamma(3, 2))
    a = numpyro.sample("a", dist.Gamma(4, 4))
    sigma = numpyro.sample("sigma", dist.Gamma(3, 6))
    b = jnp.sqrt(dt) * sigma
    
    # -- model params
    m = numpyro.sample("m", dist.Beta(3, 3))
    noise = numpyro.sample("noise", dist.LogNormal())
    
    # -- initial values
    with numpyro.plate("plate_v0", 2):
        v0 = numpyro.sample("v0", dist.Normal())
    
    def transition(carry, _):
        # -- unpack state
        v_t, = carry
        x_t, y_t = v_t
        
        # -- Model for observations
        mu_t = m * x_t + (1 - m) * y_t
        z_t = numpyro.sample("z", dist.Normal(mu_t, noise))
        
        # -- Diff equation system
        dx = tau * (x_t - x_t ** 3.0 / 3.0 + y_t)
        dy = (a - x_t) / tau
        dv = jnp.c_[dx, dy].squeeze()
        
        # -- EM Update
        v_next = numpyro.sample("v", dist.Normal(v_t + dv * dt, b))
        
        # Carry over state
        carry = (v_next,)
        return carry, None
    
    timesteps = jnp.arange(num_timesteps)
    init = (v0,)
    with numpyro.handlers.condition(data={"z": z}):
        scan(transition, init, timesteps)

And this is how you can use it:

num_timesteps = 200
dt = 0.1

# --- simulate with some fixed parameters
fixed_model2 = numpyro.handlers.condition(
    model2,
    {
        "tau": 3.0,
        "a": 1.05,
        "m": 0.2,
        "sigma": 1e-1,
        "noise": 1e-1,
        "v0": jnp.array([0.0, 0.1]),
    }
)
prior = Predictive(fixed_model2, num_samples=100)
prior_samples = prior(RNG, num_timesteps, dt)

# --- run mcmc
true_idx = 0
true_z = prior_samples["z"][true_idx]

mcmc = MCMC(
    NUTS(model2),
    num_warmup=2000,
    num_samples=2000,
    num_chains=4,
)

mcmc.run(RNG, num_timesteps, dt, true_z)
# mcmc.print_summary() # for diagnostics

# --- posterior checks
posterior = Predictive(model2, mcmc.get_samples())
posterior_samples = posterior(RNG, num_timesteps, dt)

# --- extending into the future
num_future = 50
estimates = mcmc.get_samples().copy()
estimates["v0"] = estimates.pop("v")[:, -1, :] # replace the `v0` with the last value of `v` and drop `v`.
predictive = Predictive(model2, posterior_samples=estimates)
predictive_samples = predictive(RNG, num_future, dt)

Note that I have changed some of the flat priors from the original example, simply because I don't like flat priors.

omarfsosa avatar Jan 06 '22 18:01 omarfsosa

+1 for Euler Maruyama in numpyro

sokol11 avatar Jan 07 '22 07:01 sokol11

In case you haven't seen this, the pymc implementation of the log_prob is just five lines and should be the same in numpyro/jax IMO (see below). That said, pymc does not have the random sampling method implemented, and I am not sure how to approach it. I guess one would need to do something similar to the logp calculation, but backwards?

# Quoted from pymc's timeseries distributions; imports added here for context,
# assuming pymc v4's module layout (`at` is aesara.tensor):
import aesara.tensor as at

from pymc.distributions import distribution
from pymc.distributions.continuous import Normal

class EulerMaruyama(distribution.Continuous):
    r"""
    Stochastic differential equation discretized with the Euler-Maruyama method.
    Parameters
    ----------
    dt: float
        time step of discretization
    sde_fn: callable
        function returning the drift and diffusion coefficients of SDE
    sde_pars: tuple
        parameters of the SDE, passed as ``*args`` to ``sde_fn``
    """

    def __init__(self, dt, sde_fn, sde_pars, *args, **kwds):
        super().__init__(*args, **kwds)
        self.dt = dt = at.as_tensor_variable(dt)
        self.sde_fn = sde_fn
        self.sde_pars = sde_pars

    def logp(self, x):
        """
        Calculate log-probability of EulerMaruyama distribution at specified value.
        Parameters
        ----------
        x: numeric
            Value for which log-probability is calculated.
        Returns
        -------
        TensorVariable
        """
        xt = x[:-1]
        f, g = self.sde_fn(x[:-1], *self.sde_pars)
        mu = xt + self.dt * f
        sd = at.sqrt(self.dt) * g
        return at.sum(Normal.dist(mu=mu, sigma=sd).logp(x[1:]))
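For what it's worth, a rough numpyro/jax sketch of that same vectorized log density (the function name is illustrative, not an official API):

import jax.numpy as jnp
import numpyro.distributions as dist

def euler_maruyama_log_prob(x, dt, sde_fn, sde_pars):
    # x holds the values at all time steps, so the density factorizes into
    # independent Gaussian transitions and can be computed without a loop.
    xt = x[:-1]
    f, g = sde_fn(xt, *sde_pars)  # drift and diffusion at each step
    mu = xt + dt * f
    sd = jnp.sqrt(dt) * g
    return dist.Normal(mu, sd).log_prob(x[1:]).sum()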

sokol11 avatar Jan 10 '22 09:01 sokol11

If you want to implement it as a distribution, here is a sketch

def sample(self, key, sample_shape=()):
    noises = random.normal(key, (num_steps,))
    # then use scan for the simulation to collect the ys, where
    # y_next = y_curr + drift(y_curr, t_curr) * dt + diffusion(y_curr, t_curr) * sqrt(dt) * noise_curr

For a non-trivial sample_shape, we can vmap over scan and reshape the output. Using a distribution will make inference faster because there is no scan in log_prob (see the GaussianRandomWalk distribution as an example).
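A sketch of that vmap-and-reshape idea, reusing the hypothetical euler_maruyama(key, x0, drift, diffusion, dt, num_steps) simulator sketched earlier in the thread:

import math
from jax import random, vmap

def sample_with_shape(key, sample_shape, x0, drift, diffusion, dt, num_steps):
    # One key per requested sample; vmap runs the scan-based simulator in parallel.
    keys = random.split(key, math.prod(sample_shape))
    ys = vmap(lambda k: euler_maruyama(k, x0, drift, diffusion, dt, num_steps))(keys)
    return ys.reshape(sample_shape + (num_steps,))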

fehiepsi avatar Jan 11 '22 00:01 fehiepsi

Thanks @fehiepsi, this is interesting. Though, admittedly, I am a little confused: the sample method you drafted looks very similar to the log_prob calculation in pymc's EulerMaruyama logp method that I quoted above.

Pymc has a developer guide (https://docs.pymc.io/en/v3/developer_guide.html) which says that all of their distribution classes have a logp method, which returns the log probability used by all inference methods, and a .random method, used to simulate data for posterior predictive checks.

Do the .log_prob and .sample methods in numpyro have the same intent as .logp and .random in pymc? Or is it a different design?

Also, I'm not sure we need scan at all, since the calculation in the pymc example appears vectorized. But it's hard for me to speculate without understanding how the distribution classes are actually intended to work (or how scan works, for that matter) lol

sokol11 avatar Jan 11 '22 08:01 sokol11

Yes, they serve the same purpose. You can take a look at the GaussianRandomWalk distribution, where the sample method uses sequential ops (more precisely, cumsum) and log_prob is vectorized. Simulation in time is sequential by nature; the reason we can vectorize log_prob is that we already know the values at all steps. As for scan, you can think of it as a for loop where the next value depends on the current one. (I used notation like drift and diffusion, and the same formula, just to make things easier to follow, since the request is for features similar to pymc3; the main reference is Wikipedia.)
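To illustrate the cumsum point, a GaussianRandomWalk-style sample can express the whole sequential simulation as a single vectorized op (a sketch, not the actual numpyro source):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
num_steps, scale = 100, 0.5
steps = scale * random.normal(key, (num_steps,))
walk = jnp.cumsum(steps)  # the entire random walk, no Python loop needed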

Maybe we can have an API (which is similar to jax odeint) like

class EulerMaruyama(...):
    arg_constraints = {"t": ordered_vector}
    def __init__(self, sde_fn, init_dist, t, validate_args=None):
        # we use t rather than (dt, num_steps) to make the api simpler and more general,
        # `init_dist` will define prior for the init value,
        # event dimensions will be decided by `init_dist`
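A hypothetical usage of that proposed API, with an Ornstein-Uhlenbeck process as the example SDE (all names below are assumptions, since the distribution is only a sketch at this point):

import jax.numpy as jnp

theta, mu, sigma = 1.0, 0.0, 0.5  # OU parameters

def sde_fn(x, t):
    return theta * (mu - x), sigma  # returns (drift, diffusion)

t = jnp.linspace(0.0, 10.0, 101)
# x = numpyro.sample("x", EulerMaruyama(sde_fn, dist.Normal(), t))  # once implemented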

fehiepsi avatar Jan 11 '22 12:01 fehiepsi

> Maybe we can have an API (which is similar to jax odeint) like ... (quoting fehiepsi's comment above)

I like it!

I gave it a few naive attempts, trying to subclass GaussianRandomWalk or the Distribution base class, but it appears to be slightly more involved than that lol

sokol11 avatar Jan 12 '22 13:01 sokol11

I want to try to contribute. I have a question.

# event dimensions will be decided by init_dist

How is that possible? init_dist decides the batch dimensions, right?

yayami3 avatar Nov 22 '22 15:11 yayami3

Assuming that init_dist has shape batch_shape + event_shape, I guess the Euler-Maruyama distribution would have shape batch_shape + new_event_shape, where new_event_shape = (len(t),) + event_shape. Maybe we can assume init_dist.batch_shape == () for simplicity. What do you think?
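For instance, a shape sketch under that assumption (hypothetical, since the distribution does not exist yet):

import jax.numpy as jnp
import numpyro.distributions as dist

init_dist = dist.Normal(jnp.zeros(3), 1.0).to_event(1)  # batch_shape (), event_shape (3,)
t = jnp.linspace(0.0, 1.0, 11)                          # 11 time points
# the resulting EulerMaruyama distribution would then have batch_shape ()
# and event_shape (11, 3) == (len(t),) + init_dist.event_shape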

fehiepsi avatar Nov 23 '22 21:11 fehiepsi

@fehiepsi Thanks, but it's still a bit confusing to me. I made a draft PR in which I assume init_dist has shape batch_shape and event_shape is defined by t.

Could you comment on the PR?

yayami3 avatar Nov 26 '22 11:11 yayami3

Resolved in #1504 by @yayami3 💯!

fehiepsi avatar Dec 16 '22 23:12 fehiepsi