
make vi (posterior) mean and std accessible as a structured xarray

Open markusschmaus opened this issue 2 years ago • 16 comments

What is this PR about? For VI, the mean and std of an approximation are currently only available as an unstructured, flat Aesara variable. This leads to frequent questions about how to extract these properties from the posterior. This PR adds two new properties that evaluate the Aesara variables and transform them into a structured xarray Dataset using the available coords.

See also: https://discourse.pymc.io/t/quality-of-life-improvements-to-advi/10254
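As a rough illustration of the transformation described above, the sketch below splits a flat parameter vector into named arrays and wraps them in an `xarray.Dataset`. The helper `flat_to_dataset` and its `layout` list of `(name, shape, dims)` tuples are hypothetical, not PyMC's implementation; the real conversion has to follow PyMC's own ordering of the flat vector, and VI approximations usually live on the transformed (unconstrained) space, which may add extra steps.

```python
import numpy as np
import xarray as xr


def flat_to_dataset(flat, layout, coords=None):
    """Split a flat parameter vector into named arrays and wrap them in a Dataset.

    flat   : 1-D array holding every variable's values back to back
    layout : hypothetical list of (name, shape, dims) tuples, in the same order as `flat`
    coords : optional mapping from dim name to coordinate values
    """
    flat = np.asarray(flat)
    data_vars = {}
    start = 0
    for name, shape, dims in layout:
        size = int(np.prod(shape, dtype=int))  # number of scalars for this variable (1 for scalars)
        data_vars[name] = (dims, flat[start:start + size].reshape(shape))
        start += size
    if start != flat.size:
        raise ValueError(f"layout covers {start} values but flat has {flat.size}")
    return xr.Dataset(data_vars, coords=coords)


# Made-up example: a scalar mu, a vector beta indexed by city, and a scalar sigma
mean_flat = np.array([0.5, 1.0, 2.0, 3.0, 0.1])
layout = [("mu", (), ()), ("beta", (3,), ("city",)), ("sigma", (), ())]
ds = flat_to_dataset(mean_flat, layout, coords={"city": ["A", "B", "C"]})
```

With the coords attached, a value can be selected by label, e.g. `ds["beta"].sel(city="B")`, instead of by its position in the flat vector.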

Checklist

Major / Breaking Changes

  • None

Bugfixes / New features

  • new feature: get mean and std as an xarray Dataset

Docs / Maintenance

  • Included docstrings for the existing mean, std, and cov properties

markusschmaus, Aug 31 '22

The error messages from the "Read the Docs build" don't look like they have anything to do with this PR. Is there something wrong with the docstrings?

markusschmaus, Aug 31 '22

@markusschmaus Thanks for your contribution! As for the Read the Docs error, yes, please ignore it; sorry for the false alarm.

For your code submission, I'll review it now at a "code level" but will defer to my colleagues with more VI expertise for the math and user questions. Which brings me to my next question: the PR is currently marked as a draft. Did you want a review now, or were you still planning to work on it some more?

canyon289, Sep 01 '22

Codecov Report

Merging #6086 (3c6af3a) into main (0b191ad) will increase coverage by 0.06%. The diff coverage is 94.44%.

❗ Current head 3c6af3a differs from the pull request's most recent head 671cb9c. Consider uploading reports for commit 671cb9c to get more accurate results.


```diff
@@            Coverage Diff             @@
##             main    #6086      +/-   ##
==========================================
+ Coverage   89.54%   89.60%   +0.06%     
==========================================
  Files          72       72              
  Lines       12929    12947      +18     
==========================================
+ Hits        11577    11601      +24     
+ Misses       1352     1346       -6     
```
| Impacted Files | Coverage Δ |
|---|---|
| pymc/variational/opvi.py | 87.24% <94.44%> (+0.20%) ⬆️ |
| pymc/step_methods/hmc/base_hmc.py | 90.55% <0.00%> (+0.78%) ⬆️ |
| pymc/variational/approximations.py | 90.14% <0.00%> (+2.81%) ⬆️ |
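The headline numbers in the diff can be sanity-checked by hand. This sketch assumes coverage is simply hits divided by tracked lines, rounded to two decimals (the report lists no partial lines here); the helper name `coverage` is made up for illustration.

```python
def coverage(hits, lines):
    """Line coverage in percent, rounded to two decimals."""
    return round(100 * hits / lines, 2)


before = {"lines": 12929, "hits": 11577, "misses": 1352}
after = {"lines": 12947, "hits": 11601, "misses": 1346}

# Every tracked line is either a hit or a miss
for snap in (before, after):
    assert snap["hits"] + snap["misses"] == snap["lines"]

# Change in headline coverage between main and the PR head
delta = round(coverage(after["hits"], after["lines"]) - coverage(before["hits"], before["lines"]), 2)
```

The 18 new lines brought 24 extra hits (and 6 fewer misses), which is why overall coverage rises even though lines were added.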

codecov[bot], Sep 01 '22

Thanks. I left it as a draft because I was still trying to debug the docs issue. I'm polishing a few things and will then change the status.

markusschmaus, Sep 01 '22

Should an InferenceData object be used to store output rather than a raw xarray? Just thinking in terms of consistency with the rest of PyMC, and a natural place to put samples from the approximate posterior.

fonnesbeck, Sep 02 '22

I thought about using an InferenceData object, but it doesn't really fit, as it is meant for storing samples and not the parameters of the posterior approximation.

It's already possible to get a true InferenceData object by calling sample on the approximation, including samples from the approximate posterior.

markusschmaus, Sep 02 '22

I was thinking of something along the lines of the "sample_stats" or "observed_data" groups in the InferenceData object, which are not samples but related variables from the model.

fonnesbeck, Sep 02 '22

Let's go through the options:

https://python.arviz.org/en/latest/schema/schema.html#schema

  • posterior, posterior_predictive, prior, prior_predictive, predictions: No, since all of these are supposed to be samples
  • sample_stats_prior: No, the approximation isn't related to any prior samples
  • log_likelihood: No
  • observed_data: No, this is supposed to be the data the posterior is conditioned on
  • predictions_constant_data: No, since the approximation has nothing to do with any predictions
  • constant_data: No, since the approximations are not data included in the model
  • sample_stats: Probably not, since this is supposed to relate to the samples in the posterior group, which we don't have

So a straightforward xarray Dataset looks best to me.

markusschmaus, Sep 02 '22

Can we just get a dictionary? Never mind, you are passing dims around as well.

ricardoV94, Sep 02 '22

Yeah, I find the coords just too useful not to use them. I considered returning a dict of numpy arrays when no coords are given, but this would result in an inconsistent return type, which I always find a pain to deal with when a library does this.
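To illustrate the consistent-return-type argument, here is a minimal sketch with a hypothetical helper `to_dataset` that always returns a Dataset. When no dims are supplied it falls back to placeholder names such as `beta_dim_0`, similar to ArviZ's default dimension naming, instead of switching to a dict of NumPy arrays.

```python
import numpy as np
import xarray as xr


def to_dataset(arrays, dims=None, coords=None):
    """Always return an xarray.Dataset, whether or not dims/coords are provided."""
    dims = dims or {}
    data_vars = {}
    for name, values in arrays.items():
        values = np.asarray(values)
        # Placeholder dim names when the caller gives none
        var_dims = dims.get(name, tuple(f"{name}_dim_{i}" for i in range(values.ndim)))
        data_vars[name] = (var_dims, values)
    return xr.Dataset(data_vars, coords=coords)


no_coords = to_dataset({"beta": [1.0, 2.0, 3.0]})
with_coords = to_dataset(
    {"beta": [1.0, 2.0, 3.0]},
    dims={"beta": ("city",)},
    coords={"city": ["A", "B", "C"]},
)
```

Callers can then write one code path for both cases; only the dimension labels differ.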

markusschmaus, Sep 03 '22

Was hoping you'd be able to add custom attributes to InferenceData but I guess you can't. 😢

fonnesbeck, Sep 10 '22

> Was hoping you'd be able to add custom attributes to InferenceData but I guess you can't. 😢

What do you mean? You could, last time I checked.

ricardoV94, Sep 11 '22

The spec is only enforced with a warning:

```python
for key in kwargs:
    if key not in SUPPORTED_GROUPS_ALL:
        key_list.append(key)
        warnings.warn(
            f"{key} group is not defined in the InferenceData scheme", UserWarning
        )
```

https://github.com/arviz-devs/arviz/blob/2a7bf0f2cb26bfe273e800406249547507d4fdd4/arviz/data/inference_data.py#L147
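The check quoted above boils down to "accept, but warn". Here is a stdlib-only sketch of that behavior; `collect_groups` and the abbreviated `SUPPORTED_GROUPS` list are hypothetical stand-ins (ArviZ's real `SUPPORTED_GROUPS_ALL` is longer, and other ArviZ releases may treat custom groups differently).

```python
import warnings

# Hypothetical, heavily abbreviated list of schema groups
SUPPORTED_GROUPS = ["posterior", "posterior_predictive", "sample_stats", "observed_data"]


def collect_groups(**groups):
    """Keep every group, but warn about names outside the schema (mirrors the quoted check)."""
    accepted = {}
    for key, value in groups.items():
        if key not in SUPPORTED_GROUPS:
            warnings.warn(f"{key} group is not defined in the InferenceData scheme", UserWarning)
        accepted[key] = value
    return accepted


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded, not suppressed
    result = collect_groups(posterior="samples", vi_params="mean/std")
```

The non-standard group is stored anyway; the only consequence is the `UserWarning`.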

So I could ignore this warning and wrap the xarray Datasets in an InferenceData object that doesn't conform to the spec, though I still don't see any benefit in doing so. It wouldn't make sense to start sampling just to fill any of the other groups, since the whole point of this PR is to let the user extract the mean and std without sampling.

If it's about syntax and you prefer approx.params_data["mean"] to approx.mean_data, one option would be to wrap them in a dictionary.
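To make the two syntaxes concrete, here is a toy class exposing the same Datasets both ways. `ApproxParamsView` is a hypothetical name for illustration, not PyMC's approximation class; `mean_data`, `std_data`, and `params_data` follow the names used in the comment above.

```python
import xarray as xr


class ApproxParamsView:
    """Hypothetical holder exposing the same Datasets through two access styles."""

    def __init__(self, mean: xr.Dataset, std: xr.Dataset):
        self._mean = mean
        self._std = std

    # Attribute style, e.g. approx.mean_data
    @property
    def mean_data(self) -> xr.Dataset:
        return self._mean

    @property
    def std_data(self) -> xr.Dataset:
        return self._std

    # Dictionary style, e.g. approx.params_data["mean"]
    @property
    def params_data(self) -> dict:
        return {"mean": self._mean, "std": self._std}


view = ApproxParamsView(
    mean=xr.Dataset({"mu": ((), 0.5)}),
    std=xr.Dataset({"mu": ((), 0.2)}),
)
```

Both styles return the very same objects, so the choice is purely about ergonomics.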

markusschmaus, Sep 12 '22

I meant you can add attributes to one of the "allowed" groups. In this case I was thinking you could add it to the posterior group.

ricardoV94, Sep 12 '22

The point of the PR is to avoid sampling just for extracting the mean and std, so there are no posterior samples and no posterior group. I could create an empty group, but I still don't see the benefit.

markusschmaus, Sep 12 '22

> The point of the PR is to avoid sampling just for extracting the mean and std, so there are no posterior samples and no posterior group. I could create an empty group, but I still don't see the benefit.

Fair enough

ricardoV94, Sep 12 '22

@fonnesbeck @ricardoV94 What's the future of this PR?

markusschmaus, Oct 02 '22

> Can we just get a dictionary? Never mind, you are passing dims around as well.

General PSA: Datasets have a dict-like interface, so for most purposes you can simply ignore the fact that you have a Dataset and treat it as a dictionary.
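For example, the usual dict idioms work directly on a Dataset: iterating yields the data variable names, `ds[name]` returns a `DataArray`, and `.items()` can build a plain dict of NumPy arrays. The variable names below are made up for illustration.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"mu": ((), 0.5), "beta": (("city",), np.array([1.0, 2.0, 3.0]))},
    coords={"city": ["A", "B", "C"]},
)

names = list(ds)            # data variable names, like dict keys
has_mu = "mu" in ds         # membership test
beta = ds["beta"]           # item access returns a DataArray
as_numpy = {name: da.values for name, da in ds.items()}  # plain dict of arrays
```

One caveat: `in` also matches coordinate names (here `"city" in ds` is true), whereas iteration and `.items()` only cover data variables.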

OriolAbril, Oct 10 '22

Let's rebase and merge this.

ghost, Nov 29 '22

@markusschmaus are you good with us merging this as-is, or are there any additional changes you'd like to make? Sorry it's taken so long; it kind of fell off the radar!

fonnesbeck, Nov 29 '22

Merge conflicts fixed by #6387

fonnesbeck, Dec 12 '22