
Allow partial imputation with `pm.observe`

Open · ricardoV94 opened this issue 1 year ago · 1 comment

One tricky part will be making this work in conjunction with #6932.

Partial imputation is a model transformation that usually happens in `model.register_rv`. It creates two model RVs (the observed and the unobserved components), which are then joined together in a deterministic carrying the original name. That way the result looks like a single entity in case the variable is used downstream elsewhere (or just so it shows up in the trace).

What should happen when the user calls `observe` twice, with nan, on the originally named variable? We don't want to partially observe the deterministic that joins the two components. Also, if the second `observe` has no nan, do we want to remove the now-useless imputation? How could we do that?
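On the last question, one hypothetical building block is a check that tells the transform whether a split is needed at all: if the new observations contain no NaN, a plain observed variable suffices and the imputation could be dropped. `split_for_imputation` below is an illustrative helper, not an existing PyMC function.

```python
import numpy as np


def split_for_imputation(data):
    """Hypothetical helper: partition observations for partial imputation.

    Returns (observed_values, observed_idx, missing_idx) using flat indices,
    or None when there is no NaN and a plain observed variable would suffice.
    """
    data = np.asarray(data, dtype=float)
    mask = np.isnan(data)
    if not mask.any():
        # Nothing missing: no imputation split needed
        return None
    return data[~mask], np.flatnonzero(~mask), np.flatnonzero(mask)


print(split_for_imputation([1.0, 2.0]))
print(split_for_imputation([1.0, np.nan, 3.0]))
```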

ricardoV94 · Mar 20 '24 17:03

Snippet from @wd60622 that reveals the missing functionality:

```python
import pymc as pm
import numpy as np


def normal_declaration(data):
    # Observed data passed directly: NaNs trigger automatic partial imputation
    coords = {
        "idx": range(len(data)),
    }
    with pm.Model(coords=coords) as model:
        pm.Normal(
            "obs",
            mu=pm.Normal("mu"),
            sigma=pm.HalfNormal("sigma"),
            observed=data,
            dims="idx",
        )

    return model


def work_around(data):
    # Same model declared generatively, then conditioned with pm.observe
    coords = {
        "idx": range(len(data)),
    }
    with pm.Model(coords=coords) as generative_model:
        pm.Normal(
            "obs",
            mu=pm.Normal("mu"),
            sigma=pm.HalfNormal("sigma"),
            dims="idx",
        )

    return pm.observe(generative_model, {"obs": data})


seed = sum(map(ord, "impute observe bug"))
rng = np.random.default_rng(seed)

mu = 5
sigma = 0.25

data = rng.normal(mu, sigma, size=250)

# Mark roughly a quarter of the entries as missing
missing_idx = rng.choice([True, False, False, False], size=data.shape)
data[missing_idx] = np.nan

# Works: the NaN entries are imputed automatically
with normal_declaration(data):
    idata = pm.sample()

# Fails: pm.observe does not split the NaN entries into an imputed component
with work_around(data):
    # SamplingError: Initial evaluation of model at starting point failed!
    idata_workaround = pm.sample()
```

ricardoV94 · Jul 26 '24 10:07