arviz
[WIP] Add marginal likelihood estimation via bridge sampling
Co-authored-by: @junpenglao
Description
Provides an estimate of the (log) marginal likelihood, estimated using bridge sampling (as described in Gronau, Quentin F., et al., 2017), building on an implementation from @junpenglao. This could be expanded to add Bayes factor functionality if desired.
The bridge sampler uses samples from the posterior, so the `log_marginal_likelihood_bridgesampling` function takes as parameters an InferenceData object that has a posterior group, as well as an unnormalized log probability function (e.g. `model.logp_array` from a PyMC3 model).
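For context, the iterative bridge sampling scheme from Gronau et al. (2017) can be sketched on a toy 1-D problem. This is an illustrative standalone example (sampling directly from the "posterior" and using a univariate normal proposal), not the PR's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Unnormalized target: q1(x) = exp(-x**2 / 2); true normalizer is sqrt(2*pi).
def q1(x):
    return np.exp(-0.5 * x ** 2)

# "Posterior" draws (sampled directly here, instead of via MCMC).
x1 = rng.normal(size=4000)

# Fit a normal proposal q2 to the posterior draws, then sample from it.
mu, sigma = x1.mean(), x1.std()
x2 = rng.normal(mu, sigma, size=4000)

def q2(x):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Iterative bridge sampling update for the normalizing constant Z.
l1 = q1(x1) / q2(x1)  # density ratios at posterior draws
l2 = q1(x2) / q2(x2)  # density ratios at proposal draws
n1, n2 = len(x1), len(x2)
s1, s2 = n1 / (n1 + n2), n2 / (n1 + n2)

z = 1.0
for _ in range(100):
    numerator = np.mean(l2 / (s1 * l2 + s2 * z))
    denominator = np.mean(1.0 / (s1 * l1 + s2 * z))
    z = numerator / denominator

print(z)  # close to sqrt(2*pi) ≈ 2.5066
```

A real implementation would work in log space for numerical stability, but the plain-space version above shows the structure of the iteration.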
Because we fit a multivariate normal proposal distribution to the posterior samples, it is helpful to have samples that are transformed, e.g. to have support on the real line instead of on a bounded interval. Although these transformed samples are created as part of the NUTS sampling, I believe they're not currently included in InferenceData (see issue #230). So, `log_marginal_likelihood_bridgesampling` currently takes a dict whose keys are variable names and whose values are the associated transformation (or the identity). You could get this from a PyMC3 model with something like the following, although maybe there's a better way:
```python
def get_transformation_dict_from_model(model):
    """Return a dict giving the transformation for each variable.

    Parameters
    ----------
    model : a PyMC3 model

    Returns
    -------
    transformation_dict : dict
        Keys are (str) names of model variables (their pre-transformation
        names); values are their associated transformation as a function
        that (elementwise) transforms an array. If the variable has no
        transformation associated, we use the identity function.
    """
    transformation_dict = {}
    for var_name in model.named_vars:
        if not var_name.endswith('__'):
            var = getattr(model, var_name)
            transformation = getattr(var, 'transformation', None)
            if transformation is not None:
                transformation_dict[var_name] = transformation.forward_val
            else:  # if no transformation, use identity
                transformation_dict[var_name] = lambda x: x
    return transformation_dict
```
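To illustrate how such a dict could feed into fitting the multivariate normal proposal, here is a hypothetical helper (`fit_mvn_proposal` is not part of the PR) that applies each transformation, stacks the transformed draws into a matrix, and fits the proposal by sample mean and covariance:

```python
import numpy as np

def fit_mvn_proposal(posterior_samples, transformation_dict):
    """Hypothetical helper: fit a multivariate normal to transformed draws.

    posterior_samples : dict mapping variable name -> 1-D array of draws
    transformation_dict : dict mapping variable name -> elementwise transform
    """
    names = sorted(transformation_dict)
    columns = [np.asarray(transformation_dict[name](posterior_samples[name]))
               for name in names]
    stacked = np.column_stack(columns)  # shape: (n_draws, n_variables)
    mean = stacked.mean(axis=0)
    cov = np.cov(stacked, rowvar=False)
    return names, mean, cov

# Toy usage: 'sigma' is positive, so it is log-transformed to the real line;
# 'mu' already lives on the real line and keeps the identity transform.
rng = np.random.default_rng(1)
draws = {"mu": rng.normal(size=500),
         "sigma": np.abs(rng.normal(size=500)) + 0.1}
transforms = {"mu": lambda x: x, "sigma": np.log}
names, mean, cov = fit_mvn_proposal(draws, transforms)
print(mean.shape, cov.shape)  # (2,) (2, 2)
```

In the actual function, the transformed draws would instead come from the InferenceData posterior group, with chains and draws stacked into one sample dimension.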
Curious to hear any thoughts or feedback! I'm happy to write tests for this as well, but wanted to wait to get initial feedback before doing so.
Should we use a class approach like we do with reloo? This way different backends only need to create specific methods with uniform parameters and output.
cc @OriolAbril
I think we should define how to include unconstrained variables in InferenceData and solve #230 by choosing between options 1 and 2. Going with option 3 is turning out to be much more work and inconvenient.
I have skimmed the code and have many ideas, mostly related to using xarray more. But I am not sure it is worth it to start changing things yet until we have decided on the issue.
Just checking in to see if there's anything I can be helpful with here!
Very sorry about the other PR taking so long, but it has finally been merged. I can take care of rebasing if that helps; then I'll try to add some high-level comments.