# Problem conditioning on `vmap` Poisson random variables
Hi everyone,
I'm finding oryx a really clean approach to implementing a PPL. However, I'm confused about conditional sampling.
## Poisson process with exp random walk intensity
As an attempt to get into the structure of oryx I'm trying to sample from a probabilistic program which represents:
- A hierarchical random walk (that is, a random walk where the parameters are themselves random variables)
- A further `Exp` transform on the random walk, representing the intensity of a Poisson process. This is observed.
- Inference done with NUTS from `blackjax`
## Code
### Dependencies
```python
import jax
# jax.config.update("jax_enable_x64", True)
import oryx.core.ppl as ppl
import oryx.bijectors as bijectors
import oryx.distributions as tfd
import blackjax
import jax.numpy as jnp
import jax.random as random
from functools import partial
import matplotlib.pyplot as plt
```
### Probabilistic program
Note that I've implemented the observation likelihood as a `vmap` over `intensity`, representing conditional independence of the observations.
```python
@partial(jax.jit, static_argnames=["n"])
def hierarchical_random_walk_dist(n, init, step_scale):
    rw_transformation = bijectors.Chain([bijectors.Shift(init), bijectors.Scale(step_scale), bijectors.Cumsum()])
    return tfd.TransformedDistribution(tfd.MultivariateNormalDiag(jnp.zeros(n), jnp.ones(n)), rw_transformation)

def poisson_process(key, n, init_prior_loc, init_prior_scale, step_scale_prior):
    key_poi, key_intensity, key_init, key_step = random.split(key, 4)
    init = ppl.random_variable(tfd.Normal(init_prior_loc, init_prior_scale), name="init")(key_init)
    step_scale = ppl.random_variable(tfd.HalfNormal(step_scale_prior), name="step_scale")(key_step)
    intensity = ppl.random_variable(tfd.TransformedDistribution(
        hierarchical_random_walk_dist(n, init, step_scale),
        bijectors.Exp()),
        name="intensity")(key_intensity)
    poi_keys = random.split(key_poi, n)
    poi = jax.vmap(lambda ky, x: ppl.random_variable(tfd.Poisson(x), name="poi")(ky))(poi_keys, intensity)
    return poi
```
### Sample some data from the model
```python
sampler = ppl.joint_sample(poisson_process)
key_rn = random.PRNGKey(1234)
true_params = sampler(key_rn, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)

plt.plot(true_params['intensity'])
plt.scatter(range(len(true_params['poi'])), true_params['poi'], color='red')
plt.xlabel('time')
plt.ylabel('Intensity')
plt.title('Intensity Random Variable')
plt.show()
```
This looks reasonable.
## Inference
I split the observed data out from the rest of the parameters:
```python
true_data = true_params.pop('poi')
true_params
```
```
{'init': Array(0.01620068, dtype=float32), 'intensity': Array([2.21524286e+00, 1.02979910e+00, 3.47778827e-01, 1.59668833e-01, 1.43431634e-01, 1.89372957e-01, 1.48013055e-01, 8.88996869e-02, 1.40973523e-01, 1.33820206e-01, 6.49942383e-02, 4.85427566e-02, 1.31719252e-02, 1.61807686e-02, 1.93237811e-02, 7.30752666e-03, 3.75548634e-03, 3.23897717e-03, 4.46824962e-03, 3.59713589e-03, 3.48433293e-03, 4.54167370e-03, 8.35305359e-03, 7.45324651e-03, 1.69865470e-02, 2.82925181e-03, 4.80814092e-03, 5.73506765e-03, 1.29247606e-02, 2.23501474e-02, 2.60949116e-02, 2.78504174e-02, 2.95239929e-02, 3.01535334e-02, 3.17600109e-02, 5.96645549e-02, 1.58876508e-01, 4.15319920e-01, 2.83102959e-01, 3.94434422e-01, 6.21528685e-01, 9.56910610e-01, 4.71480668e-01, 3.51778269e-01, 3.12051624e-01, 3.87135684e-01, 4.41913813e-01, 8.34466696e-01, 1.12293482e+00, 9.62291718e-01, 6.46639347e-01, 1.22468376e+00, 1.33461881e+00, 9.76860523e-01, 1.60133433e+00, 4.31086159e+00, 3.78359699e+00, 4.50091076e+00, 7.61642456e+00, 9.94997692e+00, 1.83034401e+01, 1.86841354e+01, 1.93865471e+01, 4.10644569e+01, 5.11959839e+01, 3.74023285e+01, 1.46664228e+01, 2.73789101e+01, 5.09101982e+01, 2.61694183e+01, 2.90790100e+01, 1.15916996e+01, 1.44228182e+01, 8.16150761e+00, 1.21826038e+01, 8.52718925e+00, 8.82525539e+00, 1.47077036e+01, 1.31940975e+01, 8.21146393e+00, 5.06118011e+00, 3.73972368e+00, 8.66150951e+00, 8.86765766e+00, 1.82184372e+01, 2.03960953e+01, 1.50705147e+01, 3.58565903e+01, 3.94253273e+01, 1.51045656e+01, 1.46200066e+01, 1.30218935e+01, 8.60846615e+00, 6.86474085e+00, 5.52572966e+00, 6.91005898e+00, 4.49717140e+00, 2.01037908e+00, 2.75382376e+00, 3.28753996e+00], dtype=float32), 'step_scale': Array(0.5200044, dtype=float32)}
```
Then I follow the usual blackjax approach to sampling (based on their example of using oryx):
```python
def logdensity_fn(params):
    theta = dict(params, poi=true_data)
    return ppl.joint_log_prob(poisson_process)(theta, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)

ll = logdensity_fn(true_params)
# Array(-50.80758, dtype=float32)
```
```python
# Warmup
inference_key = jax.random.PRNGKey(12)
rng_key, warmup_key = jax.random.split(inference_key)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, true_params, 1000)
kernel = blackjax.nuts(logdensity_fn, **parameters).step

# Sampling
def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)
    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos

rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 2000)
```
## Issues
- The main issue is that this silently fails to sample from the posterior (or I'm not understanding the sample structure):
```python
plt.figure(figsize=(12, 6))
plt.plot(true_params['intensity'], label='True Intensity', color='blue')
for i in range(100):  # Plot the first 100 sampled intensities
    plt.plot(states.position['intensity'][i], alpha=0.5)
plt.xlabel('Time')
plt.ylabel('Intensity')
plt.title('True Intensity vs Sampled Intensities')
plt.legend()
plt.show()
```
- Warnings about `f32` conversion, e.g.

```
UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. minval = minval + np.zeros([1] * final_rank, dtype=dtype)
```

  which suggests that the underlying Poisson distribution is struggling, but...
- Errors if enabling `f64` conversion: if conversion to double precision is allowed, the model fails at `joint_sample` with:

```
TypeError: Tensor conversion requested dtype <class 'numpy.float32'> for array with dtype float64: Traced<ShapedArray(float64[100])>with<DynamicJaxprTrace>
```
## Steps forward
I don't have a huge amount of JAX/oryx experience, so it would be great if someone could point out whether I've made a glaring error, or whether there is some issue with `joint_log_prob` in combination with `Poisson`, or with the way I've implemented the Poisson link.
Hey! I took a quick look at this, and a few things:
- Running the code snippets you provided does a better job of fitting the parameters than the image you attached, but it isn't great.
- I ran things through bayeux's debugger, and it looks like you probably need to constrain some (all?) of the parameters to be positive.
- When I ran `adam` on the problem after constraining the parameters, I found the log likelihood was much larger with `intensity` parameters that were near zero than with the true parameters!
You can try running `pip install -Uq bayeux-ml` and then
```python
import bayeux as bx
import jax

def transform_fn(params):
    params['init'] = jax.nn.softplus(params['init'])
    params['step_scale'] = jax.nn.softplus(params['step_scale'])
    params['intensity'] = jax.nn.softplus(params['intensity'])
    return params

bx_model = bx.Model(logdensity_fn, test_point=true_params, transform_fn=transform_fn)
assert bx_model.mcmc.numpyro_nuts.debug(jax.random.key(0))
opt = bx_model.optimize.optax_adam(jax.random.key(0))
jax.vmap(logdensity_fn)(opt.params)
```
if you want to see this. My best guess is that you may need to modify the model somewhat!
Hey @ColCarroll ,
Thanks for having a look!
> it looks like you probably need to constrain some (all?) of the parameters to be positive.
The maths of the example is:
$$
\begin{aligned}
B_0 &\sim \text{Normal}(\mu_0, \sigma_0) \\
\sigma &\sim \text{HalfNormal}(\tau) \\
\epsilon_t &\sim \text{Normal}(0, 1) \qquad \text{IID } \forall t \\
B_t &= B_0 + \sigma \sum_{s=1}^t \epsilon_s \\
\lambda_t &= \exp(B_t) \\
y_t &\sim \text{Poisson}(\lambda_t).
\end{aligned}
$$
So $B_0$ doesn't need to be positive, and the other parameters like $\sigma$ are drawn from non-negative distributions like `oryx.distributions.HalfNormal`. None of the forward sampling gives me inappropriately signed parameters?
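As a sanity check on the signs, here's a minimal pure-JAX sketch of the forward model above (the hyperparameter values are the ones used in the question, and `forward` is my name for this helper, not an oryx API). The intensity is positive by construction, since it passes through `exp`:

```python
import jax.numpy as jnp
import jax.random as random

# Minimal forward sample of the model above, without oryx
def forward(key, n=100, mu0=0.0, s0=0.05, tau=0.25):
    k0, ks, ke = random.split(key, 3)
    B0 = mu0 + s0 * random.normal(k0)         # B_0 ~ Normal(mu0, s0)
    sigma = jnp.abs(tau * random.normal(ks))  # sigma ~ HalfNormal(tau)
    eps = random.normal(ke, (n,))             # eps_t ~ Normal(0, 1), IID
    B = B0 + sigma * jnp.cumsum(eps)          # random walk B_t
    lam = jnp.exp(B)                          # intensity lambda_t = exp(B_t)
    return sigma, lam

sigma, lam = forward(random.PRNGKey(0))
# sigma >= 0 and every lambda_t > 0, whatever the key
```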
> I ran things through bayeux's debugger,

Thanks for flagging this, I'll give it a try!

> I found the log likelihood was much larger with `intensity` parameters that were near zero than the true parameters!

That's a bad sign...
Ah, right, forward sampling will give you properly signed parameters, but notice that you are directly evaluating the probability of `intensity`, which must be positive. This turned up for me because `bayeux` attempts automatic initialization, and this would consistently fail and produce `nan`s for the log prob. (It looks like `init` need not be positive.)
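To make the failure mode concrete, here is a small illustration (not the oryx internals, just the log-normal density that an `Exp`-transformed Gaussian implies for each intensity value): evaluating it at a negative point goes through `log` and returns `nan`, which is exactly what an unconstrained sampler will eventually propose.

```python
import jax.numpy as jnp

def lognormal_logpdf(x, mu=0.0, sigma=1.0):
    # density of exp(Normal(mu, sigma)); only defined for x > 0
    z = (jnp.log(x) - mu) / sigma
    return -0.5 * z**2 - jnp.log(x * sigma) - 0.5 * jnp.log(2 * jnp.pi)

print(lognormal_logpdf(1.0))   # finite
print(lognormal_logpdf(-1.0))  # nan, because log(-1) is nan
```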
> Ah, right, forward sampling will give you properly signed parameters, but notice that you are directly evaluating the probability of `intensity`, which must be positive. This turned up for me because `bayeux` attempts automatic initialization, and this would consistently fail and produce `nan`s for the log prob. (It looks like `init` need not be positive.)
Thanks for the insight! I think I'm too used to PPLs that automatically (or semi-automatically) apply variable bijectors.
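For anyone landing here later, a hand-rolled version of what those PPLs do automatically is to run the sampler in unconstrained space and add the change-of-variables term yourself. A sketch (the helper name and parameter layout are mine, not an oryx/blackjax API): sample `log_intensity` instead of `intensity`, map it through `exp`, and add the log-Jacobian `sum(log_intensity)` to the log density:

```python
import jax.numpy as jnp

# Hypothetical helper: wrap a log density defined on positive 'intensity'
# so a sampler can work with an unconstrained 'log_intensity' instead.
def unconstrained_logdensity(params, constrained_logdensity):
    u = params['log_intensity']
    theta = {k: v for k, v in params.items() if k != 'log_intensity'}
    theta['intensity'] = jnp.exp(u)  # always positive
    # change of variables: log|d exp(u)/du| = u, summed over components
    return constrained_logdensity(theta) + jnp.sum(u)
```

Passing the question's `logdensity_fn` through a wrapper like this (with the same treatment for `step_scale`) keeps every NUTS proposal inside the support, which is essentially what `transform_fn` does in the bayeux snippet above.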