Saving and Loading Models

Open • hxk1633 opened this issue 1 year ago • 7 comments

Is there a way to save a fitted model to disk and then load it later to make predictions?

hxk1633 • May 03 '23 23:05

Currently there's no way to do so. You could write something to store the inference data and the metadata of the model (the formula, the model family, the priors, etc.) and then load it again. It's along the lines of what I would be interested in doing, but I haven't had time lately.

tomicapretto • May 04 '23 17:05

@hxk1633 You can try something like this from pymc-experimental

5hv5hvnk • Jul 07 '23 12:07

Could you provide some details about how I could go about implementing something like this? Currently we have a model built with Bambi and we need to do predictions in production on new data that changes every few hours. I really need a way to serialise the Bambi model so it can be loaded again and serve predictions via an API.

humana • Apr 05 '24 12:04

@humana have a look at the following example

Script 1: This is where you first create and "train" your model:

```python
import pickle

import bambi as bmb

# "my_data" is a placeholder; load your own training data here
df = bmb.load_data("my_data")

# Store all the arguments you pass to `bmb.Model()` in a dict that is pickled
family = "gaussian"
formula = "y ~ x + z"
priors = {
    "Intercept": bmb.Prior("Normal", mu=0.5, sigma=1),
    "x": bmb.Prior("Normal", mu=0, sigma=1),
    "z": bmb.Prior("Normal", mu=0, sigma=2),
}

# The training data is stored too, because bmb.Model() needs it to re-create the model
args_dict = {
    "formula": formula,
    "data": df,
    "family": family,
    "priors": priors,
}

# Create and fit the model
model = bmb.Model(**args_dict)
idata = model.fit(random_seed=1234)

# Store things on disk
# Model metadata (required to re-create the model)
with open("model_args_dict.pickle", "wb") as handle:
    pickle.dump(args_dict, handle)

# InferenceData object (contains the draws from the posterior)
idata.to_netcdf("idata.nc")
```
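Because the pickle holds Bambi objects (such as `bmb.Prior`) and the model is re-created from them later, a different library version in the serving environment could make unpickling fail or change behavior. A minimal precaution, assuming a small JSON sidecar file is acceptable, is to record the versions at save time and compare them at load time. The helper names `save_versions` and `check_versions`, the package list, and the file name are hypothetical, not part of Bambi:

```python
import json
import warnings
from importlib import metadata

# Packages whose versions are recorded (assumption: adjust to your environment)
PACKAGES = ("bambi", "pymc", "arviz")


def installed_versions(packages=PACKAGES):
    """Return {package: version}, using None for packages that are not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def save_versions(path, packages=PACKAGES):
    """Write the current package versions to a JSON sidecar file."""
    with open(path, "w") as handle:
        json.dump(installed_versions(packages), handle, indent=2)


def check_versions(path, packages=PACKAGES):
    """Warn about every package whose version differs from the saved one.

    Returns {package: (saved_version, current_version)} for the mismatches.
    """
    with open(path) as handle:
        saved = json.load(handle)
    current = installed_versions(packages)
    mismatches = {
        name: (saved.get(name), current[name])
        for name in packages
        if saved.get(name) != current[name]
    }
    for name, (old, new) in mismatches.items():
        warnings.warn(f"{name}: saved with {old}, running {new}")
    return mismatches
```

One way to use it: call `save_versions("versions.json")` next to the `pickle.dump()` call, and `check_versions("versions.json")` before `pickle.load()` in the prediction script.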

Script 2: This is what you use to obtain predictions on new datasets without having to build or fit the underlying PyMC model again:

```python
import pickle

import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd

# This is a new data frame
df_new = pd.DataFrame({"x": np.random.normal(size=10), "z": np.random.normal(size=10)})

# Load the original arguments
with open("model_args_dict.pickle", "rb") as handle:
    args_dict_loaded = pickle.load(handle)

# Re-create the Bambi model (this doesn't re-create the PyMC model unless you .build() it)
model_loaded = bmb.Model(**args_dict_loaded)

# Load the posterior draws (and the other groups stored in the file)
idata_loaded = az.from_netcdf("idata.nc")

# Predict on the new dataset. With inplace=False the predictions are returned in a
# new InferenceData object instead of being added to idata_loaded.
# kind="pps" requests posterior predictive draws of the response
# (newer Bambi releases may name this option kind="response").
idata_pred = model_loaded.predict(idata_loaded, data=df_new, inplace=False, kind="pps")
```
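For a production API you usually want one point estimate and an interval per row of `df_new` rather than thousands of draws. With `kind="pps"` the draws of the response end up in the `posterior_predictive` group under the response name (`y` here), with the chain and draw dimensions first. A minimal sketch that works on a plain NumPy array of shape `(chain, draw, observation)`; the helper name `summarize_draws` and the 94% default are assumptions, and it uses equal-tailed quantiles rather than ArviZ's HDI:

```python
import numpy as np


def summarize_draws(draws, prob=0.94):
    """Summarize posterior predictive draws for each observation.

    draws: array of shape (chain, draw, observation).
    Returns the mean and an equal-tailed interval containing `prob`
    of the pooled draws for every observation.
    """
    draws = np.asarray(draws)
    # Pool chains and draws into a single sample axis: (samples, observation)
    flat = draws.reshape(-1, draws.shape[-1])
    tail = (1 - prob) / 2
    lower, upper = np.quantile(flat, [tail, 1 - tail], axis=0)
    return {"mean": flat.mean(axis=0), "lower": lower, "upper": upper}
```

Applied to the prediction script, this could look like `summarize_draws(idata_pred.posterior_predictive["y"].values)`, assuming a univariate response named `y`.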

tomicapretto • Apr 09 '24 13:04

@GStechschulte I think making this pattern more visible in our docs could help more people. What do you think?

Also, it's actually quite fast, as it doesn't have to compile much on the PyMC side.

tomicapretto • Apr 09 '24 13:04

Thank you, this is pretty much what I came up with myself after reading through what the predict function would need, but I was worried I might have missed something because it looked too simple. Very helpful.

humana • Apr 09 '24 14:04

Great! This is possible because Bambi "knows" how to compute a lot of things without relying on the PyMC/PyTensor graph. Given a Bambi model and the InferenceData object, we can generate predictions without building the PyMC model at all. However, we do use PyMC's draw() function to obtain draws from a PyMC distribution. We could avoid this step in many cases, but we would need to maintain a larger and more confusing codebase.
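To illustrate the idea (this is not Bambi's actual implementation), here is a sketch of what predicting from stored posterior draws amounts to for the Gaussian model `y ~ x + z` used in this thread: the linear predictor is plain array arithmetic over the draws, and only the final step samples from the response distribution. Bambi uses `pm.draw()` for that step; NumPy's random generator stands in for it here. The function name `predict_gaussian` and the dict layout of `posterior` are hypothetical:

```python
import numpy as np


def predict_gaussian(posterior, x_new, z_new, rng=None):
    """Posterior predictive draws for y ~ x + z with a Gaussian likelihood.

    posterior: dict of arrays with shape (chain, draw) for the keys
        "Intercept", "x", "z" and "sigma" (hypothetical layout).
    x_new, z_new: 1-D arrays with the new predictor values.
    Returns (mu, y), both with shape (chain, draw, observation).
    """
    rng = np.random.default_rng(rng)
    x_new = np.asarray(x_new, dtype=float)
    z_new = np.asarray(z_new, dtype=float)
    # Trailing axis so (chain, draw) parameters broadcast over observations
    b0 = posterior["Intercept"][..., None]
    bx = posterior["x"][..., None]
    bz = posterior["z"][..., None]
    sigma = posterior["sigma"][..., None]
    # Linear predictor: no PyMC/PyTensor graph involved
    mu = b0 + bx * x_new + bz * z_new
    # Response draws (Bambi delegates this step to pm.draw())
    y = rng.normal(loc=mu, scale=sigma)
    return mu, y
```

With the files from this thread, the dict could be built as `{name: idata_loaded.posterior[name].values for name in ["Intercept", "x", "z", "sigma"]}`, assuming those are the variable names stored in your InferenceData.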

I'm happy you were able to work it out. Just let me know if you have any other questions.

tomicapretto • Apr 09 '24 14:04