
Problem conditioning on `vmap` Poisson random variables

Status: Open. SamuelBrand1 opened this issue 1 year ago (4 comments).

Hi everyone,

I'm finding oryx a really clean approach to implementing a PPL. However, I'm confused about conditional sampling.

Poisson process with exp random walk intensity

To get familiar with the structure of oryx, I'm trying to sample from a probabilistic program that represents:

  1. A hierarchical random walk (that is, a random walk whose parameters are themselves random variables).
  2. An exp transform of the random walk, giving the intensity of a Poisson process whose counts are observed.
  3. Inference with NUTS from blackjax.
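For reference, the generative process in points 1 and 2 can be forward-simulated without oryx. The following is a minimal plain-JAX sketch, assuming it mirrors the oryx program further down (same priors, same shift/scale/cumsum order); `simulate` is a hypothetical helper, not part of oryx or blackjax.

```python
import jax
import jax.numpy as jnp

def simulate(key, n, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25):
    """Hypothetical helper: forward-simulate random walk -> exp -> Poisson in plain JAX."""
    k_init, k_step, k_eps, k_poi = jax.random.split(key, 4)
    # B_0 ~ Normal(init_prior_loc, init_prior_scale)
    init = init_prior_loc + init_prior_scale * jax.random.normal(k_init)
    # sigma ~ HalfNormal(step_scale_prior): absolute value of a zero-mean normal draw
    step_scale = step_scale_prior * jnp.abs(jax.random.normal(k_step))
    # B_t = B_0 + sigma * cumsum(eps), with eps_t iid N(0, 1)
    eps = jax.random.normal(k_eps, (n,))
    log_intensity = init + step_scale * jnp.cumsum(eps)
    intensity = jnp.exp(log_intensity)
    # y_t ~ Poisson(lambda_t), conditionally independent given the intensity
    poi = jax.random.poisson(k_poi, intensity)
    return {"init": init, "step_scale": step_scale, "intensity": intensity, "poi": poi}

draws = simulate(jax.random.PRNGKey(0), n=50)
```

Because every quantity is built from normal draws and then transformed, forward simulation never produces a negative intensity, even though the intensity itself is a constrained quantity.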

Code

Dependencies

import jax
# jax.config.update("jax_enable_x64", True)

import oryx.core.ppl as ppl
import oryx.bijectors as bijectors
import oryx.distributions as tfd
import blackjax


import jax.numpy as jnp
import jax.random as random

from functools import partial
import matplotlib.pyplot as plt

Prob Program

Note that I've implemented the link as a vmap over intensity, representing the conditional independence of the observations.

@partial(jax.jit, static_argnames=["n"])
def hierarchical_random_walk_dist(n, init, step_scale):
    rw_transformation = bijectors.Chain([bijectors.Shift(init), bijectors.Scale(step_scale), bijectors.Cumsum()])
    return tfd.TransformedDistribution(tfd.MultivariateNormalDiag(jnp.zeros(n), jnp.ones(n)), rw_transformation)

def poisson_process(key, n, init_prior_loc, init_prior_scale, step_scale_prior):
    key_poi, key_intensity, key_init, key_step = random.split(key, 4)
    init = ppl.random_variable(tfd.Normal(init_prior_loc, init_prior_scale), name = "init")(key_init)
    step_scale = ppl.random_variable(tfd.HalfNormal(step_scale_prior), name = "step_scale")(key_step)
    intensity = ppl.random_variable(tfd.TransformedDistribution(
        hierarchical_random_walk_dist(n, init, step_scale),
        bijectors.Exp()),
        name = "intensity")(key_intensity)
    poi_keys = random.split(key_poi, n)
    poi = jax.vmap(lambda ky, x: ppl.random_variable(tfd.Poisson(x), name = "poi")(ky))(poi_keys, intensity)
    return poi
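The vmap link amounts to evaluating a scalar Poisson log-pmf once per (count, intensity) pair and, since the observations are conditionally independent given the intensity, adding the results. A minimal plain-JAX sketch of that idea on toy data; it does not use oryx, and `poisson_logpmf` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson

def poisson_logpmf(y, rate):
    # Scalar log-pmf: y * log(rate) - rate - log(y!)
    return y * jnp.log(rate) - rate - gammaln(y + 1.0)

intensity = jnp.array([0.5, 1.0, 2.0, 4.0])  # toy values, not from the thread
counts = jnp.array([0, 1, 3, 2])

# One log-pmf per (count, intensity) pair, mapped like the vmap in the model.
per_obs = jax.vmap(poisson_logpmf)(counts, intensity)
# Conditional independence: the joint log-likelihood is the sum of per-observation terms.
joint = jnp.sum(per_obs)
# Cross-check against JAX's built-in Poisson log-pmf.
reference = jnp.sum(poisson.logpmf(counts, intensity))
```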

Sample some data from the model

sampler = ppl.joint_sample(poisson_process)
key_rn = random.PRNGKey(1234)
true_params = sampler(key_rn, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)

plt.plot(true_params['intensity'])
plt.scatter(range(len(true_params['poi'])), true_params['poi'], color='red')
plt.xlabel('time')
plt.ylabel('Intensity')
plt.title('Intensity Random Variable')
plt.show()

[Figure: "Intensity Random Variable", sampled intensity over time with the Poisson counts as red points]

This looks reasonable.

Inference

I split the observed data out from the rest of the parameters:

true_data = true_params.pop('poi')
true_params

{'init': Array(0.01620068, dtype=float32), 'intensity': Array([2.21524286e+00, 1.02979910e+00, 3.47778827e-01, 1.59668833e-01, 1.43431634e-01, 1.89372957e-01, 1.48013055e-01, 8.88996869e-02, 1.40973523e-01, 1.33820206e-01, 6.49942383e-02, 4.85427566e-02, 1.31719252e-02, 1.61807686e-02, 1.93237811e-02, 7.30752666e-03, 3.75548634e-03, 3.23897717e-03, 4.46824962e-03, 3.59713589e-03, 3.48433293e-03, 4.54167370e-03, 8.35305359e-03, 7.45324651e-03, 1.69865470e-02, 2.82925181e-03, 4.80814092e-03, 5.73506765e-03, 1.29247606e-02, 2.23501474e-02, 2.60949116e-02, 2.78504174e-02, 2.95239929e-02, 3.01535334e-02, 3.17600109e-02, 5.96645549e-02, 1.58876508e-01, 4.15319920e-01, 2.83102959e-01, 3.94434422e-01, 6.21528685e-01, 9.56910610e-01, 4.71480668e-01, 3.51778269e-01, 3.12051624e-01, 3.87135684e-01, 4.41913813e-01, 8.34466696e-01, 1.12293482e+00, 9.62291718e-01, 6.46639347e-01, 1.22468376e+00, 1.33461881e+00, 9.76860523e-01, 1.60133433e+00, 4.31086159e+00, 3.78359699e+00, 4.50091076e+00, 7.61642456e+00, 9.94997692e+00, 1.83034401e+01, 1.86841354e+01, 1.93865471e+01, 4.10644569e+01, 5.11959839e+01, 3.74023285e+01, 1.46664228e+01, 2.73789101e+01, 5.09101982e+01, 2.61694183e+01, 2.90790100e+01, 1.15916996e+01, 1.44228182e+01, 8.16150761e+00, 1.21826038e+01, 8.52718925e+00, 8.82525539e+00, 1.47077036e+01, 1.31940975e+01, 8.21146393e+00, 5.06118011e+00, 3.73972368e+00, 8.66150951e+00, 8.86765766e+00, 1.82184372e+01, 2.03960953e+01, 1.50705147e+01, 3.58565903e+01, 3.94253273e+01, 1.51045656e+01, 1.46200066e+01, 1.30218935e+01, 8.60846615e+00, 6.86474085e+00, 5.52572966e+00, 6.91005898e+00, 4.49717140e+00, 2.01037908e+00, 2.75382376e+00, 3.28753996e+00], dtype=float32), 'step_scale': Array(0.5200044, dtype=float32)}

Then I follow the usual blackjax approach to sampling (based on their example of using oryx):

def logdensity_fn(params):
    theta = dict(params, poi = true_data)
    return ppl.joint_log_prob(poisson_process)(theta, n=100, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25)

ll = logdensity_fn(true_params)
# Array(-50.80758, dtype=float32)

# Warmup
inference_key = jax.random.PRNGKey(12)
rng_key, warmup_key = jax.random.split(inference_key)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, true_params, 1000)
kernel = blackjax.nuts(logdensity_fn, **parameters).step

# Sampling

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos

rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 2000)
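On the sample structure: `jax.lax.scan` stacks every leaf of the per-step output along a new leading axis, so `states.position` keeps the same dict keys as `true_params`, and `states.position['intensity']` has shape `(num_samples, n)`. A minimal sketch for summarising such draws, assuming a hypothetical helper `summarize_draws` and random stand-in data in place of real MCMC output:

```python
import jax
import jax.numpy as jnp

def summarize_draws(draws, burn_in=0, q=(0.05, 0.95)):
    """Hypothetical helper: summarise draws stacked along axis 0, shape (num_samples, n)."""
    kept = draws[burn_in:]
    mean = jnp.mean(kept, axis=0)
    # jnp.quantile with an array of q returns shape (len(q), n)
    lower, upper = jnp.quantile(kept, jnp.array(q), axis=0)
    return mean, lower, upper

# Stand-in for states.position['intensity']: 2000 draws of a length-100 vector.
fake_draws = jnp.exp(jax.random.normal(jax.random.PRNGKey(0), (2000, 100)))
mean, lower, upper = summarize_draws(fake_draws, burn_in=500)
```

Plotting `mean` with a band between `lower` and `upper` is often easier to read than overlaying individual draws.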

Issues

  1. The main issue is that this silently fails to sample from the posterior (or I'm misunderstanding the structure of the samples):
plt.figure(figsize=(12, 6))
plt.plot(true_params['intensity'], label='True Intensity', color='blue')
for i in range(100):  # Plotting the first 100 sampled intensities
    plt.plot(states.position['intensity'][i], alpha=0.5)
plt.xlabel('Time')
plt.ylabel('Intensity')
plt.title('True Intensity vs Sampled Intensities')
plt.legend()
plt.show()

[Figure: "True Intensity vs Sampled Intensities"]

  2. Warnings about f32 conversion, e.g.:

UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. minval = minval + np.zeros([1] * final_rank, dtype=dtype)

This suggests that the underlying Poisson distribution is struggling, but...

  3. Errors if enabling f64: if double precision is allowed (jax_enable_x64), the model fails at joint_sample with the error:

TypeError: Tensor conversion requested dtype <class 'numpy.float32'> for array with dtype float64: Traced<ShapedArray(float64[100])>with<DynamicJaxprTrace>

Steps forward

I don't have much JAX/oryx experience, so it would be great if someone could point out whether I've made a glaring error, or whether there is some issue with joint_log_prob in combination with Poisson or with the way I've implemented the Poisson link.

(SamuelBrand1, Feb 03 '25 12:02)

Hey! I took a quick look at this, and a few things:

  1. Running the code snippets you provided fits the parameters better than the image you attached suggests, but the fit still isn't great.
  2. I ran things through bayeux's debugger, and it looks like you probably need to constrain some (all?) of the parameters to be positive.
  3. When I ran adam on the problem after constraining the parameters, I found the log likelihood was much larger with intensity parameters that were near zero than the true parameters!

You can try running `pip install -Uq bayeux-ml` and then:

import bayeux as bx
import jax

def transform_fn(params):
  params = dict(params)  # copy, so the caller's dict is not mutated in place
  params['init'] = jax.nn.softplus(params['init'])
  params['step_scale'] = jax.nn.softplus(params['step_scale'])
  params['intensity'] = jax.nn.softplus(params['intensity'])
  return params

bx_model = bx.Model(logdensity_fn, test_point=true_params, transform_fn=transform_fn)

assert bx_model.mcmc.numpyro_nuts.debug(jax.random.key(0))

opt = bx_model.optimize.optax_adam(jax.random.key(0))

jax.vmap(logdensity_fn)(opt.params)

if you want to see this. My best guess is that you may need to modify the model somewhat!


(ColCarroll, Feb 03 '25 20:02)

Hey @ColCarroll ,

Thanks for having a look!

> it looks like you probably need to constrain some (all?) of the parameters to be positive.

The maths of the example is:

$$
\begin{aligned}
B_0 &\sim \text{Normal}(\mu_0, \sigma_0) \\
\sigma &\sim \text{HalfNormal}(\tau) \\
\epsilon_t &\sim \text{Normal}(0,1) \qquad \text{IID }\forall t \\
B_t &= B_0 + \sigma \sum_{s=1}^t \epsilon_s \\
\lambda_t &= \exp(B_t) \\
y_t &\sim \text{Poisson}(\lambda_t).
\end{aligned}
$$

So $B_0$ doesn't need to be positive, and the other parameters, such as $\sigma$, are drawn from non-negative distributions like `oryx.distributions.HalfNormal`. As far as I can tell, forward sampling never gives me inappropriately signed parameters?
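The same model can also be read as a density over the intensities themselves. Inverting $B_t = \log\lambda_t$ and $\epsilon_t = (B_t - B_{t-1})/\sigma$ and applying the change-of-variables formula gives the following sketch, derived from the equations above, writing $\phi$ for the standard normal density and $n$ for the number of time points (both notational assumptions):

$$
\log p(\lambda_{1:n} \mid B_0, \sigma) = \sum_{t=1}^{n} \log \phi\!\left(\frac{\log\lambda_t - \log\lambda_{t-1}}{\sigma}\right) - n\log\sigma - \sum_{t=1}^{n}\log\lambda_t, \qquad \lambda_0 := e^{B_0}.
$$

The $-n\log\sigma$ term comes from scaling by $\sigma$ and the $-\sum_t \log\lambda_t$ term from the exponential; the cumulative sum has unit Jacobian determinant. Every term involving $\log\lambda_t$ is only defined for $\lambda_t > 0$.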

> I ran things through bayeux's debugger,

Thanks for flagging this, I'll give it a try!

> I found the log likelihood was much larger with intensity parameters that were near zero than the true parameters!

That's a bad sign...

(SamuelBrand1, Feb 03 '25 22:02)

Ah, right: forward sampling will give you properly signed parameters, but notice that you are directly evaluating the probability of `intensity`, which must be positive. This turned up for me because bayeux attempts automatic initialization, which consistently failed and produced NaNs for the log prob.

(It looks like `init` need not be positive.)
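To see concretely why a negative intensity breaks the log density, here is a hand-written version of the model's joint log density in plain JAX. It is a sketch intended to mirror what evaluating the program on a dict containing `intensity` involves, not oryx's actual implementation; `joint_log_density` is a hypothetical name, and the prior defaults are the values passed in `logdensity_fn` earlier in the thread.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import norm

def joint_log_density(params, y, init_prior_loc=0.0, init_prior_scale=0.05, step_scale_prior=0.25):
    """Hypothetical hand-written joint log density for the Poisson random-walk model."""
    init, step_scale, intensity = params["init"], params["step_scale"], params["intensity"]
    # init ~ Normal(init_prior_loc, init_prior_scale)
    lp = norm.logpdf(init, init_prior_loc, init_prior_scale)
    # step_scale ~ HalfNormal(step_scale_prior); this formula assumes step_scale > 0
    lp += jnp.log(2.0) + norm.logpdf(step_scale, 0.0, step_scale_prior)
    # Undo exp, then shift/scale/cumsum, to recover the N(0, 1) innovations.
    log_b = jnp.log(intensity)  # NaN for any negative intensity
    prev = jnp.concatenate([jnp.atleast_1d(init), log_b[:-1]])
    eps = (log_b - prev) / step_scale
    # Change of variables: -n*log(step_scale) from Scale, -sum(log intensity) from Exp.
    lp += jnp.sum(norm.logpdf(eps)) - log_b.size * jnp.log(step_scale) - jnp.sum(log_b)
    # Conditionally independent Poisson observations.
    lp += jnp.sum(y * log_b - intensity - gammaln(y + 1.0))
    return lp
```

Evaluating it with a negative entry in `intensity` returns `nan`, which matches the failure described above; sampling `log(intensity)` instead would sidestep the constraint.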

(ColCarroll, Feb 04 '25 02:02)


Thanks for the insight! I think I'm too used to PPLs that apply variable bijectors automagically (or semi-automagically).
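Doing by hand what such PPLs do automatically amounts to letting the sampler work on unconstrained values and adding the log-Jacobian of the map back to the constrained space. A minimal sketch, assuming an exp transform for positive parameters (the `transform_fn` earlier in the thread uses softplus instead) and a hypothetical helper `unconstrain`; for this model the positive keys would presumably be `intensity` and `step_scale`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def unconstrain(logdensity_fn, positive_keys):
    """Hypothetical helper: wrap a constrained log density so it accepts u = log(x)."""
    def wrapped(unconstrained):
        params = dict(unconstrained)  # copy so the caller's dict is not mutated
        log_jac = 0.0
        for k in positive_keys:
            # x = exp(u), so dx/du = exp(u) and log|dx/du| = u
            params[k] = jnp.exp(unconstrained[k])
            log_jac += jnp.sum(unconstrained[k])
        return logdensity_fn(params) + log_jac
    return wrapped

def lognormal_logdensity(params):
    # Toy target: x ~ LogNormal(0, 1), density N(log x; 0, 1) / x
    x = params["x"]
    return jnp.sum(norm.logpdf(jnp.log(x)) - jnp.log(x))

unconstrained_logdensity = unconstrain(lognormal_logdensity, ["x"])
# In unconstrained space, u = log x should be standard normal and finite for any real u.
u = jnp.array([-2.0, 0.0, 1.5])
value = unconstrained_logdensity({"x": u})
```

The sampler's draws are then in log space and need `jnp.exp` applied before plotting.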

(SamuelBrand1, Feb 04 '25 08:02)