pymc
make vi (posterior) mean and std accessible as a structured xarray
What is this PR about?
For VI, the mean and std of the approximation are currently only available as an unstructured, flat Aesara variable. This leads to frequent questions on how to extract these properties from the posterior. This PR adds two new properties which evaluate the Aesara variables and transform them into a structured xarray Dataset using the available coords.
See also: https://discourse.pymc.io/t/quality-of-life-improvements-to-advi/10254
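As a rough sketch of the idea (all names here are invented for illustration, not the PR's actual code): a flat parameter vector is split per model variable and labeled with the model's coords to produce a structured `Dataset`.

```python
import numpy as np
import xarray as xr

# Hypothetical sketch: slice a flat VI mean vector per variable and attach
# the model's dims/coords so the result is a labeled xarray Dataset.
flat_mean = np.arange(5.0)                           # evaluated flat mean vector
slices = {"mu": slice(0, 2), "sigma": slice(2, 5)}   # per-variable positions (assumed)
dims = {"mu": ("group",), "sigma": ("city",)}        # assumed model dims
coords = {"group": ["a", "b"], "city": ["x", "y", "z"]}

mean_ds = xr.Dataset(
    {
        name: xr.DataArray(
            flat_mean[sl],
            dims=dims[name],
            coords={d: coords[d] for d in dims[name]},
        )
        for name, sl in slices.items()
    }
)
```

With labels attached, entries can be selected by coordinate value, e.g. `mean_ds["sigma"].sel(city="y")`, instead of by flat index.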
Checklist
- [x] Explain important implementation details 👆
- [x] Make sure that the pre-commit linting/style checks pass.
- [x] Link relevant issues (preferably in nice commit messages)
- [x] Are the changes covered by tests and docstrings?
- [x] Fill out the short summary sections 👇
Major / Breaking Changes
- None
Bugfixes / New features
- New feature: get the mean and std as a structured xarray `Dataset`
Docs / Maintenance
- Included docstrings for the existing `mean`, `std`, and `cov` properties
The error messages of the "Read the Docs build" don't look like they have anything to do with this PR. Is there something wrong with the docstrings?
@markusschmaus Thanks for your contribution! For the read the docs error yes please ignore that, sorry for the false alarm.
For your code submission, I'll review it now at a "code level" but will defer to my more VI-savvy colleagues here for the math and user questions. Which brings me to my next question: right now the PR is marked draft, did you want a review now or were you still planning on working on this some more?
Codecov Report
Merging #6086 (3c6af3a) into main (0b191ad) will increase coverage by 0.06%. The diff coverage is 94.44%.
:exclamation: Current head 3c6af3a differs from the pull request's most recent head 671cb9c. Consider uploading reports for the commit 671cb9c to get more accurate results.
```diff
@@            Coverage Diff             @@
##             main    #6086      +/-   ##
==========================================
+ Coverage   89.54%   89.60%   +0.06%
==========================================
  Files          72       72
  Lines       12929    12947      +18
==========================================
+ Hits        11577    11601      +24
+ Misses       1352     1346       -6
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| pymc/variational/opvi.py | 87.24% <94.44%> (+0.20%) | :arrow_up: |
| pymc/step_methods/hmc/base_hmc.py | 90.55% <0.00%> (+0.78%) | :arrow_up: |
| pymc/variational/approximations.py | 90.14% <0.00%> (+2.81%) | :arrow_up: |
Thanks, I left it at draft level as I was still trying to debug the docs issue. I'm polishing up a few things and then I will change the status.
Should an `InferenceData` object be used to store the output rather than a raw xarray? Just thinking in terms of consistency with the rest of PyMC, and a natural place to put samples from the approximate posterior.
I thought about using an `InferenceData` object, but it doesn't really fit, since it is meant for storing samples, not the parameters of the posterior approximation.
It's already possible to get a true `InferenceData` object, including samples from the approximate posterior, by calling `sample` on the approximation.
I was thinking of something along the lines of the "sample_stats" or "observed_data" entries in the `InferenceData` object, which are not samples but related variables from the model.
Let's go through the options:
https://python.arviz.org/en/latest/schema/schema.html#schema
- `posterior`, `posterior_predictive`, `prior`, `prior_predictive`, `predictions`: No, since all of these are supposed to be samples
- `sample_stats_prior`: No, the approximation isn't related to any `prior` samples
- `log_likelihood`: No
- `observed_data`: No, this is supposed to be the data the posterior is conditioned on
- `predictions_constant_data`: No, since the approximation has nothing to do with any predictions
- `constant_data`: No, since the approximations are not data included in the model
- `sample_stats`: Probably not, since this is supposed to relate to the samples in the `posterior` group, which we don't have
So a straightforward xarray looks best to me.
Can we just get a dictionary? Nevermind you are passing dims around as well.
Yeah, I find the coords just too useful not to use them. I considered returning a dict of numpy arrays when no coords are given, but this would result in an inconsistent return type, which I always find a pain to deal with when a library does this.
Was hoping you'd be able to add custom attributes to `InferenceData` but I guess you can't. 😢
> Was hoping you'd be able to add custom attributes to `InferenceData` but I guess you can't. 😢
What do you mean? You can last time I checked
The spec is only enforced with a warning:
```python
for key in kwargs:
    if key not in SUPPORTED_GROUPS_ALL:
        key_list.append(key)
        warnings.warn(
            f"{key} group is not defined in the InferenceData scheme", UserWarning
        )
```
https://github.com/arviz-devs/arviz/blob/2a7bf0f2cb26bfe273e800406249547507d4fdd4/arviz/data/inference_data.py#L147
So I could ignore this warning and wrap the xarrays in an `InferenceData` object which doesn't conform to the spec, though I still don't see any benefit in doing so. It wouldn't make sense to start sampling just to be able to fill any of the other fields, since the whole point of this PR is to give the user the ability to extract the mean and std without sampling.
If it's about the syntax and you prefer `approx.params_data["mean"]` to `approx.mean_data`, it would be an option to wrap them in a dictionary.
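For illustration, a minimal sketch of what such non-spec wrapping would look like (the group name `approximation` is hypothetical): arviz accepts an unknown group name but emits the `UserWarning` from the snippet quoted earlier, and the group is still stored on the object.

```python
import warnings

import arviz as az
import xarray as xr

ds = xr.Dataset({"mu": ((), 0.5)})  # toy stand-in for the mean Dataset

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "approximation" is a hypothetical group name, not part of the schema
    idata = az.InferenceData(approximation=ds)

# The non-spec group triggers a UserWarning but is stored anyway
assert any(issubclass(w.category, UserWarning) for w in caught)
assert hasattr(idata, "approximation")
```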
I meant you can add attributes to one of the "allowed" groups. In this case I was thinking you could add it to the posterior group.
The point of the PR is to avoid sampling just for extracting the mean and std, so there are no posterior samples and no posterior group. I could create an empty group, but I still don't see the benefit.
> The point of the PR is to avoid sampling just for extracting the mean and std, so there are no posterior samples and no posterior group. I could create an empty group, but I still don't see the benefit.
Fair enough
@fonnesbeck @ricardoV94 What's the future of this PR?
> Can we just get a dictionary? Nevermind you are passing dims around as well.
General PSA, datasets have a dict-like interface so to most ends you can simply ignore the fact you have a dataset and treat it as a dictionary.
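Concretely (variable names here are illustrative), membership tests, iteration, and `.items()` all work on a `Dataset` just as on a dict:

```python
import numpy as np
import xarray as xr

# A Dataset behaves like a mapping from variable names to DataArrays,
# so dict-style code mostly works on it unchanged.
ds = xr.Dataset({"mu": ("group", np.array([0.1, 0.2])), "sigma": ((), 1.5)})

assert "mu" in ds                      # membership, like a dict
assert sorted(ds) == ["mu", "sigma"]   # iteration yields variable names

# If a plain dict of arrays is really needed, it's one comprehension away:
plain = {name: da.values for name, da in ds.items()}
```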
Let's rebase and merge this.
@markusschmaus are you good with us merging this as-is, or are there any additional changes you'd like to make? Sorry it's taken so long; this kind of fell off the radar!
Merge conflicts fixed by #6387