
Computing log_prob for tfd.Sample() with a different number of samples

Open nick-ponvert opened this issue 2 years ago • 0 comments

I would like to construct a joint distribution (I use JointDistributionCoroutineAutoBatched) for regression modeling that includes the predictors as part of the model specification. That way I can use the same model to generate a synthetic dataset, test my inference code on it, and then pin the model to real predictors and observations for inference. Here is an example to illustrate what I am trying to do.

In this example, we are regressing kcal_per_gram against neocortex_pct (from Statistical Rethinking chapter 5).

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

@tfd.JointDistributionCoroutineAutoBatched
def model():
    
    # Generative neocortex_pct
    mu_N = yield tfd.Normal(0, 0.2, name='mu_N')
    sigma_N = yield tfd.Exponential(1, name='sigma_N')
    neocortex_pct = yield tfd.Sample(tfd.Normal(mu_N, sigma_N), sample_shape=20, name='neocortex_pct')
    
    intercept = yield tfd.Normal(0, 0.2, name='intercept')
    beta_N = yield tfd.Normal(0, 0.5, name='beta_N')
    mu = intercept + beta_N * neocortex_pct
    
    sigma = yield tfd.Exponential(1, name='sigma')
    kcal_per_gram = yield tfd.Normal(mu, sigma, name='kcal_per_gram')

This lets me generate a synthetic dataset simply by drawing a prior sample. I can even fix beta_N at a chosen value during sampling and then check whether my inference algorithm recovers it (I am using the JAX substrate):

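from jax import random

key = random.PRNGKey(0)
key, prior_sample_key = random.split(key)
# beta_N is fixed at 2; every other component is drawn from the prior
prior_samples = model.sample(seed=prior_sample_key, beta_N=2.)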

The problem comes when I want to condition this distribution on real data and run the inference algorithm. My preferred way of doing this is experimental_pin():

# data_df is a pandas DataFrame holding the observed data
data_dict = {'neocortex_pct': data_df['neocortex_pct'].values, 'kcal_per_gram': data_df['kcal_per_gram'].values}
model_pinned = model.experimental_pin(data_dict)

I then use model_pinned.log_prob() together with model_pinned.experimental_default_event_space_bijector() to run either a Laplace approximation or an MCMC chain. The problem is that if my dataset has a different number of rows than the sample_shape I used when constructing the JointDistribution, I get broadcasting errors.

Ideally, the sample_shape defined in the JointDistribution would be used only for generating data, and I could then run inference with any number of observations. Is this achievable through broadcasting somehow? If not, is there a different approach you would suggest? Thanks in advance, and let me know if I can provide any more information.

nick-ponvert, Feb 24 '24 07:02