pymc
Allow reusing function in `pymc.compute_log_likelihood`
Describe the issue:
Calling `pymc.compute_log_likelihood` multiple times on the same model leads to multiple compilations via `compile_fn`. This is a time sink (in some applications) that could easily be solved by storing the compiled function.
Reproducible code example:
```python
import numpy as np
import pymc

n = 5_000
n_feat = 3
X_train = np.random.normal(size=(n, n_feat))
y_train = np.random.normal(size=(n))

with pymc.Model() as model:
    # data containers
    X = pymc.MutableData("X", X_train)
    y = pymc.MutableData("y", y_train)
    # priors
    intercept = pymc.Normal("intercept", mu=0, sigma=1)
    b = pymc.MvNormal("b", mu=np.zeros(n_feat), cov=np.eye(n_feat))
    sigma = pymc.HalfCauchy("sigma", beta=10)
    mu = intercept + pymc.math.dot(X, b).flatten()
    # likelihood
    likelihood = pymc.Normal("obs", mu=mu, sigma=sigma, observed=y)
    idata = pymc.sample()


def compute_ll_twice(model, idata):
    n_test = 5
    X_test = np.random.normal(size=(n_test, n_feat))
    y_test = np.random.normal(size=(n_test))
    with model:
        pymc.set_data({"X": X_test, "y": y_test})
        # every call compiles the log-likelihood function again via compile_fn
        for _ in range(5):
            out = pymc.compute_log_likelihood(idata, extend_inferencedata=False)


# IPython magic: profile the calls, sorted by cumulative time
%prun -l 20 -s cumtime compute_ll_twice(model, idata)
```
Error message:
No response
PyMC version information:
pymc==5.10.1
Context for the issue:
In a research application I am writing, I need to call `pymc.compute_log_likelihood` many times (sometimes refitting the model in between). The calls to `compile_fn` take up to 60% of my computation time. If this could be easily fixed, it would be extremely helpful for me. Thank you very much in advance!
All PyMC Model functions are recompiled every time, including when you call `pm.sample` (or `sample_posterior_predictive`) multiple times with the same model. It's not trivial to know whether a model has changed or if the function we are requesting is equivalent to one that was requested in the past.
The compiled function itself is pretty simple to obtain: https://github.com/pymc-devs/pymc/blob/118be0f23782945dc03c5fb36d58d6ce4a1f619f/pymc/stats/log_likelihood.py#L83-L87
What's a bit tricky is that transformed variables are not available in the trace. `compute_log_likelihood` has an ugly hack around this, but there's a better solution now:
```python
from pymc.model.conditioning import remove_value_transforms

model_wo_transforms = remove_value_transforms(model)
fn = model_wo_transforms.compile_fn(...)
# Reuse fn...
```
Then there's a bit of boilerplate to map a function over a trace that we could provide to users
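A minimal sketch of what that mapping boilerplate might look like, assuming the posterior has already been pulled out of the `InferenceData` into a dict of NumPy arrays shaped `(chain, draw, ...)`, and that the compiled point function takes a `{name: value}` dict for one draw and returns a list of arrays (one per observed variable). The helper name `map_over_draws` is hypothetical, not a PyMC API; the posterior names are assumed to match the function's inputs, which is what removing the value transforms is meant to achieve.

```python
import numpy as np


def map_over_draws(fn, posterior):
    """Apply a point function to every (chain, draw) of a posterior.

    Hypothetical helper, not part of PyMC.
    - posterior: dict mapping names to arrays shaped (chain, draw, *var_shape)
    - fn: callable taking {name: value} for a single draw and returning a
      sequence of outputs (e.g. one log-likelihood array per observed variable)
    Returns a list with one array per output, shaped (chain, draw, *out_shape).
    """
    names = list(posterior)
    n_chains, n_draws = posterior[names[0]].shape[:2]
    results = None
    for c in range(n_chains):
        for d in range(n_draws):
            # build the point for this draw and evaluate the (already compiled) fn
            point = {name: posterior[name][c, d] for name in names}
            outs = fn(point)
            if results is None:
                # allocate one output array per returned value on the first call
                results = [np.empty((n_chains, n_draws) + np.shape(o)) for o in outs]
            for res, o in zip(results, outs):
                res[c, d] = o
    return results
```

The function is compiled once outside the helper and only evaluated inside the loop, so repeated calls pay the compilation cost a single time.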
@ricardoV94 That would be a very useful addition in the case where the model has not changed!
Maybe one way would be to pass the compiled function to `compute_log_likelihood` as an optional argument?
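To illustrate the suggestion, here is a PyMC-independent sketch of an "optional precompiled function" argument. All names (`compute_loglike`, `build_fn`, `fn`) are hypothetical and do not reflect PyMC's actual API; `build_fn` stands in for the expensive compilation step (for example a call to `model.compile_fn`), and the posterior is assumed to be a dict of arrays shaped `(chain, draw, ...)`.

```python
import numpy as np


def compute_loglike(posterior, build_fn, fn=None):
    # Hypothetical API: `fn` is the optional precompiled function;
    # `build_fn` is only called when no function is supplied.
    if fn is None:
        fn = build_fn()  # expensive compilation happens only here
    names = list(posterior)
    n_chains, n_draws = posterior[names[0]].shape[:2]
    out = np.array(
        [[fn({k: posterior[k][c, d] for k in names}) for d in range(n_draws)]
         for c in range(n_chains)]
    )
    return out, fn  # hand the function back so the caller can reuse it


# Usage: count how often the "compilation" step runs.
n_compiles = 0


def build_fn():
    global n_compiles
    n_compiles += 1
    return lambda point: -0.5 * point["mu"] ** 2


posterior = {"mu": np.ones((2, 4))}
ll, fn = compute_loglike(posterior, build_fn)
for _ in range(4):
    ll, fn = compute_loglike(posterior, build_fn, fn=fn)
print(n_compiles)  # 1
```

Returning the function alongside the result keeps the caller in charge of deciding when the model has changed and a fresh compilation is needed, which sidesteps the problem of detecting model changes automatically.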
Closing this in favor of #7177
Thanks for bringing this up!