
[WIP] Add marginal likelihood estimation via bridge sampling

Open · karink520 opened this issue 2 years ago · 4 comments

Co-authored-by: @junpenglao

Description

Adds an estimate of the (log) marginal likelihood, computed via bridge sampling (as described in Gronau, Quentin F., et al., 2017), building on an implementation from @junpenglao. This could be extended to add Bayes factor functionality if desired.
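For intuition (this is not the PR's code), the iterative estimator from Gronau et al. (2017) can be sketched in a few lines. Everything below, including the function name `bridge_sampling_constant` and the toy target, is illustrative; the sketch works in ratio (non-log) space for clarity, whereas a real implementation would work in log space for numerical stability:

```python
import numpy as np

def bridge_sampling_constant(post_draws, log_q1, log_q2, prop_draws,
                             n_iter=100, tol=1e-10):
    """Estimate the normalizing constant of exp(log_q1) by bridge sampling.

    post_draws are draws from the normalized target; prop_draws are draws
    from the (normalized) proposal with log density log_q2. Implements the
    iterative scheme of Gronau et al. (2017).
    """
    n_post, n_prop = len(post_draws), len(prop_draws)
    s_post = n_post / (n_post + n_prop)
    s_prop = n_prop / (n_post + n_prop)
    # importance ratios q1/q2 evaluated at both sets of draws
    l_post = np.exp(log_q1(post_draws) - log_q2(post_draws))
    l_prop = np.exp(log_q1(prop_draws) - log_q2(prop_draws))
    r = 1.0  # initial guess for the constant
    for _ in range(n_iter):
        num = np.mean(l_prop / (s_post * l_prop + s_prop * r))
        den = np.mean(1.0 / (s_post * l_post + s_prop * r))
        r_new = num / den
        if abs(r_new - r) <= tol * abs(r):
            break
        r = r_new
    return r_new

# Toy check: unnormalized standard normal, whose constant is sqrt(2*pi)
rng = np.random.default_rng(0)
log_q1 = lambda x: -0.5 * x**2  # unnormalized target
log_q2 = lambda x: -0.5 * (x / 1.5) ** 2 - np.log(1.5 * np.sqrt(2 * np.pi))
post = rng.normal(0.0, 1.0, 20000)  # stand-in for posterior draws
prop = rng.normal(0.0, 1.5, 20000)  # draws from the proposal
z_hat = bridge_sampling_constant(post, log_q1, log_q2, prop)
# z_hat should be close to sqrt(2*pi) ≈ 2.5066
```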

The bridge sampler uses samples from the posterior, so the log_marginal_likelihood_bridgesampling function takes as a parameter an InferenceData object that has a posterior group, as well as an unnormalized log probability function (e.g. `model.logp_array` from a PyMC model).

Because we fit a multivariate normal proposal distribution to the posterior samples, it is helpful to have samples that are transformed, e.g. to have support on the real line instead of on a bounded interval. Although these transformed samples are created as part of NUTS sampling, I believe they're not currently included in InferenceData (see issue #230). So, log_marginal_likelihood_bridgesampling currently takes a dict whose keys are variable names and whose values are the associated transformations (or the identity). You could get this from a PyMC model with something like the following, although maybe there's a better way:

def get_transformation_dict_from_model(model):
    """
    Return a dict giving the transformation for each variable.

    Parameters
    ----------
    model : a PyMC model

    Returns
    -------
    transformation_dict : dict
        Keys are (str) names of model variables (their pre-transformation
        names); values are the associated transformation as a function that
        (elementwise) transforms an array. If the variable has no associated
        transformation, the identity function is used.
    """
    transformation_dict = {}
    for var_name in model.named_vars:
        if not var_name.endswith("__"):
            var = getattr(model, var_name)
            transformation = getattr(var, "transformation", None)
            if transformation is not None:
                transformation_dict[var_name] = transformation.forward_val
            else:  # no transformation, so use the identity
                transformation_dict[var_name] = lambda x: x
    return transformation_dict
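As context for the multivariate normal proposal mentioned above, fitting it to the transformed draws amounts to a mean and covariance estimate. Here is a toy sketch with made-up draws; the `transformations` dict mirrors the format produced by the helper above, and the variable names and values are invented for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)
# Toy "posterior" draws: mu is unbounded, sigma is positive (bounded below)
draws = {
    "mu": rng.normal(0.5, 0.2, 5000),
    "sigma": rng.lognormal(0.0, 0.3, 5000),
}
# Transformation dict in the format described above: each variable name maps
# to a function that takes its samples to the real line
transformations = {"mu": lambda x: x, "sigma": np.log}

# Transform each variable, stack into (n_draws, n_params), fit the proposal
transformed = np.column_stack(
    [transformations[name](values) for name, values in draws.items()]
)
proposal_mean = transformed.mean(axis=0)
proposal_cov = np.cov(transformed, rowvar=False)
```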

Curious to hear any thoughts or feedback! I'm happy to write tests for this as well, but wanted to wait to get initial feedback before doing so.

Checklist

  • [x] Follows official PR format
  • [x] New features are properly documented (with an example if appropriate)
  • [ ] Includes new or updated tests to cover the new feature
  • [x] Code style correct (follows pylint and black guidelines)
  • [ ] Changes are listed in changelog

karink520 avatar Jun 02 '22 13:06 karink520

Should we use a class-based approach like we do with reloo? That way, different backends would only need to implement specific methods with uniform parameters and output.

cc @OriolAbril

ahartikainen avatar Jun 06 '22 20:06 ahartikainen

I think we should define how to include unconstrained variables in InferenceData and solve #230 with either option 1 or option 2. It is turning out to be much more work, and more inconvenient, to go with option 3.

I have skimmed the code and have many ideas, mostly related to using xarray more. But I am not sure it is worth it to start changing things yet until we have decided on the issue.

OriolAbril avatar Jun 07 '22 16:06 OriolAbril

Just checking in to see if there's anything I can be helpful with here!

karink520 avatar Sep 21 '22 00:09 karink520

Very sorry about the other PR taking so long, but it has finally been merged. I can take care of rebasing if it helps; then I'll try to add some high-level comments.

OriolAbril avatar Dec 29 '22 00:12 OriolAbril