Log name of RVs that are being resampled in posterior predictive
Right now, the behavior of `sample_posterior_predictive` is very opaque about which RVs are being resampled. Any RVs that are downstream of shared variables or of other variables mentioned in `var_names`, as well as variables missing from the trace, will be resampled. We should output their names, similarly to how we show which variables are being sampled in `pm.sample`.
For instance, in the example below it might be unclear that `y` is being resampled, as it does not show up in any log messages or in the output of `posterior_predictive`:
```python
import pymc as pm

with pm.Model() as m:
    x = pm.MutableData("x", 0)
    y = pm.Normal("y", x)
    z = pm.Normal("z", y, observed=0)
    idata = pm.sample(chains=1, tune=0, draws=5, random_seed=1)
    pm.set_data({"x": 100})
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=1)

print(idata.posterior_predictive.data_vars)  # z (chain, draw) float64 101.8 101.5 ...
print(idata.posterior_predictive["z"].mean().values)  # 100.30893015576387
```
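The selection rule described above can be sketched in plain Python (the helper name and argument names here are hypothetical, not PyMC internals): a variable is resampled if it is named in `var_names`, is missing from the trace, or sits downstream of a shared (`MutableData`) variable.

```python
# Sketch of the current selection rule; names are illustrative, not PyMC API.
def vars_to_resample(model_vars, downstream_of_shared, vars_in_trace, var_names):
    resampled = []
    for var in model_vars:
        if (
            var in var_names
            or var not in vars_in_trace
            or var in downstream_of_shared
        ):
            resampled.append(var)
    return resampled

# In the example above: y is in the trace but downstream of the shared x,
# and z is observed (absent from the posterior), so both are redrawn.
print(vars_to_resample(
    model_vars=["y", "z"],
    downstream_of_shared=["y", "z"],
    vars_in_trace=["y"],
    var_names=[],
))  # ['y', 'z']
```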
CC @lucianopaz
This does not seem very beginner friendly. Source: I'm in a room with beginners and it's not entirely straightforward for them. Still a great feature; I'll remove the label.
I've been trying to understand what would be needed to implement this feature. What is not yet clear to me:
1. What is the exact output of `pm.sample` that should be implemented analogously for `pm.sample_posterior_predictive`? Is it the printed message `NUTS: [y]`?
For the above example, I tried to find out where the relevant variables
[y, z]explicitly occur insample_posterior_predictive. I findvars_to_sample = [z]andvars_in_trace = [y]. However, at the point when theaesarafunction is created, I findzinoutputs, but don't findyanymore in any of the arguments passed toaesara.function: https://github.com/pymc-devs/pymc/blob/a22998341919512b44b32400508d29aa09d26542/pymc/aesaraf.py#L1034-L1040 This makes me think that perhaps sampling ofyis only implicit throughz.owner. Is this conclusion correct? Indeed, when I only draw fromz, I also get the correct posterior predictive:
print(pm.draw(z, draws=1000).mean()) # 99.984246138739
Yes, because y depends on a shared variable, it is resampled. But it doesn't show in the outputs.
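The "implicit through `z.owner`" mechanism can be illustrated with a toy graph (a hypothetical `Node` class standing in for Aesara's `Apply`/`owner` structure): drawing an output walks its ancestors, so upstream RVs are redrawn even though they never appear in the `outputs` list.

```python
# Toy stand-in for graph ancestry; Node.parents plays the role of
# node.owner.inputs in the real Aesara graph.
class Node:
    def __init__(self, name, parents=()):
        self.name = name
        self.parents = parents

def ancestors(node):
    # Depth-first walk over parents, collecting ancestor names.
    seen = []
    stack = list(node.parents)
    while stack:
        n = stack.pop()
        if n.name not in seen:
            seen.append(n.name)
            stack.extend(n.parents)
    return seen

x = Node("x")
y = Node("y", parents=(x,))
z = Node("z", parents=(y,))
print(ancestors(z))  # ['y', 'x'] — y and x are pulled in implicitly
```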
However, if this is indeed the source of issue #6047, we should rethink our strategy. It might be better to change our default to not resample (non-deterministic) variables that exist in the trace, and emit a warning like:
"variable y which depends on a shared variable will not be resampled because it's already present in the trace. Include it in var_names if you wish to resample it.”
Although emitting the warning even when you are okay with the behavior is annoying.
We should only be resampling deterministics that depend on shared variables by default, not pre-existing RVs that depend on shared variables.
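A minimal sketch of that proposed default (all names here are hypothetical, not the real `sample_posterior_predictive` internals): deterministics depending on a shared variable are still resampled, while free RVs already present in the trace are kept and a warning points the user at `var_names`.

```python
import warnings

# Sketch of the proposed default; argument names are illustrative.
def select_resampled(free_rvs, deterministics, shared_dependent,
                     vars_in_trace, var_names):
    resampled = []
    for var in free_rvs + deterministics:
        if var in var_names or var not in vars_in_trace:
            # Explicitly requested, or missing from the trace: resample.
            resampled.append(var)
        elif var in shared_dependent:
            if var in deterministics:
                # Deterministics downstream of shared data are recomputed.
                resampled.append(var)
            else:
                warnings.warn(
                    f"Variable {var} depends on a shared variable but will "
                    "not be resampled because it is already in the trace. "
                    "Include it in var_names to resample it."
                )
    return resampled
```

With the example model, `y` would stay fixed (with a warning) while the observed `z`, absent from the posterior, is still redrawn.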
CC @lucianopaz
> Is it the printed message `NUTS: [y]`?
Yes. Something like `Sampling: [y, z]` in this case.
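Formatting and emitting that message could look roughly like the following (the helper name is illustrative; `pm.sample` produces its `NUTS: [y]` line through PyMC's own logger):

```python
import logging

# Hypothetical helper that builds the requested message.
def format_resampled(resampled_names):
    return "Sampling: [" + ", ".join(resampled_names) + "]"

_log = logging.getLogger("pymc")
logging.basicConfig(level=logging.INFO, format="%(message)s")
_log.info(format_resampled(["y", "z"]))  # logs "Sampling: [y, z]"
```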