pymc-marketing
plot_prior_predictive() gives error "RuntimeError: The model hasn't been fit yet, call .fit() first"
Hello!
Using the example MMM notebook on the PyMC-Marketing page, I created a basic model with the DelayedSaturatedMMM class.
I am able to .fit() the model and run all of the functions shown in the notebook, but when I try to run .plot_prior_predictive() I get the error above. Any ideas why this might be?
You probably have to call model.sample_prior_predictive first; you don't have to fit the model first. The error message could be improved, though.
Thanks for the reply @ricardoV94 .
It seems that calling sample_prior_predictive did not fix the issue. After digging into the plot_prior_predictive it appears that maybe it is related to this line?
It seems to set prior_predictive_data equal to self.prior_predictive, but my fitted model doesn't appear to have a prior_predictive property: calling mmm.prior_predictive() on it raises the same RuntimeError I posted originally.
Replacing that line with the following worked for me:
prior_predictive_data = mmm.sample_prior_predictive(X, combined=False)
This is because self.idata is always None at that point in the code, so self.prior_predictive will always raise the error (definitely undesired behavior).
A workaround for now is to build the model manually and then run plot_prior_predictive:
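To make the failure mode concrete, here is a minimal pure-Python sketch of the pattern described above. `ToyModelBuilder` is hypothetical and is not pymc-marketing code: its `prior_predictive` property only reads from `self.idata` and raises when nothing is stored there, and (as an assumption chosen to mirror what was reported in this thread) its `sample_prior_predictive` returns draws without attaching them to `idata`.

```python
from typing import Any, Optional


class ToyModelBuilder:
    """Hypothetical stand-in for the property pattern discussed in this thread."""

    def __init__(self) -> None:
        # Nothing has been sampled or fitted yet.
        self.idata: Optional[dict] = None

    @property
    def prior_predictive(self) -> Any:
        # The property only reads self.idata; it never triggers sampling itself.
        if self.idata is None or "prior_predictive" not in self.idata:
            raise RuntimeError("The model hasn't been fit yet, call .fit() first")
        return self.idata["prior_predictive"]

    def sample_prior_predictive(self) -> list:
        # Assumption for illustration: draws are returned but NOT stored on self.idata.
        return [0.0, 1.0, 2.0]

    def plot_prior_predictive(self) -> Any:
        # Mirrors the first line of the real method: read the property.
        return self.prior_predictive


model = ToyModelBuilder()
model.sample_prior_predictive()  # the return value is discarded, so idata stays None
try:
    model.plot_prior_predictive()
except RuntimeError as err:
    print(f"RuntimeError: {err}")

# Attaching the draws to idata, as the manual workaround does, lets the property resolve.
model.idata = {"prior_predictive": model.sample_prior_predictive()}
print(model.plot_prior_predictive())
```

In this toy, sampling alone never helps because the plotting path only looks at `idata`, which is why assigning the sampled result to `mmm.idata` is what makes the workaround succeed.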
mmm.build_model(X=X, y=y.to_numpy()) # `to_numpy` because at that point `y` hasn't been cast to numpy array yet, so will fail otherwise
with mmm.model:
mmm.idata = pm.sample_prior_predictive()
mmm.plot_prior_predictive();
However, this doesn't fix the bug in the codebase itself, because self.X, self.y and self.model are all None at that point. I'm not sure that's expected, or even desirable. WDYT @ricardoV94?
I also think plot_prior_predictive should be extended to plot some latent variables of interest, not only the observed variable.
Thanks @AlexAndorra this worked perfectly! Is this the only workaround at the moment?
If I understand correctly, X and y are only needed to get the dates and the preprocessed y for plotting the observed data. The data hasn't been preprocessed yet, because that only happens when we call the fit method, not when we call build_model. Should we only use this workaround after fit, then?
Just wondering whether X and y need to be X_train and y_train, respectively, so that they have the same length as in the other plots.
Hi @AlfredoJF
The X variable does make a difference, as it provides the information for scaling, media inputs, and control variables. Since this is a regression model, y is a function of X. For instance, setting the media inputs to zero will result in no media contributions in the prior predictive.
A similar workflow appears in the lift test integration, which passes X_train to the build_model method:
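The point about zero media inputs can be illustrated with a toy prior predictive simulation in NumPy. This is a hypothetical regression, not the DelayedSaturatedMMM likelihood: the exponential saturation curve, the priors, and the `toy_prior_predictive` helper are all assumptions chosen for illustration. Because the saturation term is exactly zero when spend is zero, every prior draw of the media contribution is zero, while the intercept and control terms still produce variation in y.

```python
import numpy as np


def toy_prior_predictive(
    media: np.ndarray, control: np.ndarray, draws: int = 500, seed: int = 0
) -> dict:
    """Draw prior predictive samples from a hypothetical toy MMM (not pymc-marketing's model)."""
    rng = np.random.default_rng(seed)
    n_dates = media.shape[0]
    # One value per draw for each parameter; shape (draws, 1) broadcasts over dates.
    intercept = rng.normal(0.0, 1.0, size=(draws, 1))
    beta = np.abs(rng.normal(0.0, 1.0, size=(draws, 1)))  # half-normal media effect
    lam = np.abs(rng.normal(0.0, 1.0, size=(draws, 1))) + 0.1  # saturation rate, kept > 0
    gamma = rng.normal(0.0, 1.0, size=(draws, 1))  # control coefficient
    # Exponential saturation: exactly zero when media spend is zero.
    media_contribution = beta * (1.0 - np.exp(-lam * media[None, :]))
    noise = rng.normal(0.0, 0.1, size=(draws, n_dates))
    y = intercept + media_contribution + gamma * control[None, :] + noise
    return {"y": y, "media_contribution": media_contribution}


control = np.linspace(-1.0, 1.0, 10)
no_media = toy_prior_predictive(np.zeros(10), control)
with_media = toy_prior_predictive(np.linspace(0.0, 5.0, 10), control)
print(np.abs(no_media["media_contribution"]).max())
print(with_media["media_contribution"].mean() > 0)
```

The same reasoning is why the X passed to build_model matters for the real prior predictive: the draws are conditioned on whatever media and control columns are supplied.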
mmm = DelayedSaturatedMMM(...)
mmm.build_model(X_train, y_train) # This does work with y_train = pd.Series now!
mmm.add_lift_test_measurements(...)
mmm.fit(X_train, y_train) # same inputs as passed to build_model
From https://github.com/pymc-labs/pymc-marketing/blob/916dce42b39c67ee3a04188ba492ac53f5e6504d/pymc_marketing/mmm/base.py#L256C72-L256C72:
def plot_prior_predictive(
self, samples: int = 1_000, **plt_kwargs: Any
) -> plt.Figure:
prior_predictive_data: az.InferenceData = self.prior_predictive
self.prior_predictive is None because pm.sample_prior_predictive() hasn't been called yet. This should be a relatively easy fix.
However, after model.fit() has been called, self.prior_predictive will be wiped from model.idata due to how these properties are currently written in the base ModelBuilder class. We would need to prioritize https://github.com/pymc-labs/pymc-marketing/pull/414 for that one, because it's not an easy fix.
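One possible shape for the two fixes discussed here, sketched as a hypothetical pure-Python class rather than a patch to pymc-marketing: `plot_prior_predictive` samples on demand when no prior predictive group exists, and `fit` adds its posterior to the existing container instead of replacing it, so an earlier prior predictive survives fitting. The class name, the placeholder draws, and the error message are all illustrative assumptions; the real change would need to go through ArviZ `InferenceData` and the base `ModelBuilder`.

```python
from typing import Any, Optional


class PatchedToyModelBuilder:
    """Hypothetical sketch of a fix; none of this is pymc-marketing code."""

    def __init__(self) -> None:
        self.idata: Optional[dict] = None

    def _ensure_idata(self) -> dict:
        # Create the container lazily so later steps can add groups to it.
        if self.idata is None:
            self.idata = {}
        return self.idata

    def sample_prior_predictive(self, samples: int) -> list:
        draws = [float(i) for i in range(samples)]  # placeholder draws
        # Store the draws instead of discarding them.
        self._ensure_idata()["prior_predictive"] = draws
        return draws

    def fit(self) -> None:
        # Add the posterior to the existing container rather than replacing it,
        # so an earlier prior predictive group survives fitting.
        self._ensure_idata()["posterior"] = [0.5, 0.6]

    @property
    def prior_predictive(self) -> Any:
        if self.idata is None or "prior_predictive" not in self.idata:
            raise RuntimeError(
                "No prior predictive samples found, call .sample_prior_predictive() first"
            )
        return self.idata["prior_predictive"]

    def plot_prior_predictive(self, samples: int = 1_000) -> Any:
        # Sample on demand instead of requiring .fit() to have been called.
        if self.idata is None or "prior_predictive" not in self.idata:
            self.sample_prior_predictive(samples)
        return self.prior_predictive


model = PatchedToyModelBuilder()
before_fit = model.plot_prior_predictive(samples=5)
model.fit()
after_fit = model.plot_prior_predictive(samples=5)
print(before_fit == after_fit, sorted(model.idata))
```

The merge-instead-of-replace step in `fit` is the part that corresponds to the harder problem; the on-demand sampling in `plot_prior_predictive` is the comparatively easy one.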
In the meantime, to hack together a prior predictive plot, the internals of the above plotting method could probably be copy/pasted into a notebook and tweaked into a solution.
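As a starting point for such a notebook hack, the summary step could look roughly like the NumPy helper below. It is a hypothetical sketch, not the body of the library's plotting method: it assumes the prior predictive draws for the observed variable have already been extracted as an array of shape (chain, draw, date) (for example from the `idata` produced by the manual workaround; the variable name depends on the model), and the 50% and 94% interval levels are arbitrary choices.

```python
import numpy as np


def prior_predictive_bands(draws: np.ndarray, levels=(0.5, 0.94)) -> dict:
    """Summarize prior predictive draws of shape (chain, draw, date) into central intervals."""
    # Pool chains and draws so each date has one flat set of samples.
    flat = draws.reshape(-1, draws.shape[-1])
    bands = {"median": np.median(flat, axis=0)}
    for level in levels:
        lower_q = (1.0 - level) / 2.0
        # np.quantile with two probabilities returns an array of shape (2, n_dates).
        lower, upper = np.quantile(flat, [lower_q, 1.0 - lower_q], axis=0)
        bands[level] = (lower, upper)
    return bands


rng = np.random.default_rng(1)
fake_draws = rng.normal(size=(1, 500, 52))  # placeholder data, not real model output
bands = prior_predictive_bands(fake_draws)
print(bands["median"].shape, bands[0.94][0].shape)
```

Each `(lower, upper)` pair can then be drawn against the dates with matplotlib's `ax.fill_between(dates, lower, upper)`, with the observed y overlaid as a line.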