pymc-experimental

Pathfinder gives confident wrong answer with small sample prediction

Open fonnesbeck opened this issue 1 year ago • 9 comments

This example is taken from the baseball case study in pymc-examples. We fit a beta-binomial model to some baseball batting data:

import pandas as pd
import pymc as pm

data = pd.read_csv(pm.get_data("efron-morris-75-data.tsv"), sep="\t")

N = len(data)
player_names = data["FirstName"] + " " + data["LastName"]
# coords = {"player_names": player_names.tolist()}

with pm.Model() as baseball_model:
    at_bats = pm.MutableData("at_bats", data["At-Bats"].to_numpy())
    n_hits = pm.MutableData("n_hits", data["Hits"].to_numpy())
    baseball_model.add_coord("player_names", player_names, mutable=True)

    phi = pm.Uniform("phi", lower=0.0, upper=1.0)

    kappa_log = pm.Exponential("kappa_log", lam=1.5)
    kappa = pm.Deterministic("kappa", pm.math.exp(kappa_log))

    theta = pm.Beta("theta", alpha=phi * kappa, beta=(1.0 - phi) * kappa, dims="player_names")
    y = pm.Binomial("y", n=at_bats, p=theta, observed=n_hits, dims="player_names")

and then add a prediction for a fictional player that has zero hits in 4 appearances:

with baseball_model:
    theta_new = pm.Beta("theta_new", alpha=phi * kappa, beta=(1.0 - phi) * kappa)
    y_new = pm.Binomial("y_new", n=4, p=theta_new, observed=0)

What should occur (and does with either pymc.sample or pymc.fit) is that, because the sample size for y_new is so small, its estimate should be shrunk toward the population mean. Here is the population of players:

[image: posterior estimates of theta for the observed players]

and the population mean is given by phi:

[image: posterior distribution of phi]

However, the estimate for theta_new is much too large (larger than for the most extreme player in the fitting dataset), with a high degree of posterior confidence:

[image: posterior distribution of theta_new]

Running the same model with pm.fit or pm.sample returns more reasonable estimates just under the population mean.
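For intuition about the expected shrinkage: conditional on phi and kappa, the update for theta_new is conjugate, so it can be checked by hand. A sketch with illustrative values for phi and kappa (hypothetical round numbers, not the fitted posterior means):

```python
# Conjugate Beta-Binomial update for theta_new, given a
# Beta(phi * kappa, (1 - phi) * kappa) prior and 0 hits in 4 at-bats.
# phi and kappa values below are illustrative only.
phi, kappa = 0.27, 30.0            # hypothetical population mean and concentration
alpha = phi * kappa                # 8.1
beta = (1.0 - phi) * kappa         # 21.9
hits, at_bats = 0, 4

post_mean = (alpha + hits) / (alpha + beta + at_bats)
print(post_mean)  # ~0.238: only slightly below the population mean phi
```

With only 4 at-bats against a concentration of kappa pseudo-observations, the posterior mean stays close to phi, which is the behavior pm.sample and pm.fit recover and Pathfinder does not.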

Using PyMC 5.10.1 and pymc-experimental from the main branch.

fonnesbeck avatar Dec 13 '23 02:12 fonnesbeck

Not sure; the Pathfinder result underestimates kappa and theta: [image: posterior estimates]

But this is probably a property of Pathfinder itself; I don't work with it enough to provide a good perspective. @ColCarroll has a bit more experience; maybe he has some ideas?

junpenglao avatar Dec 13 '23 10:12 junpenglao

See also the issues I found before with the 8 school example, where it would basically return the initval for whatever mu was: https://gist.github.com/ricardoV94/eafd20ac47d63525253b0a8adf5e5d76

ricardoV94 avatar Dec 13 '23 11:12 ricardoV94

Yeah, Pathfinder has a jaxopt dependency that has a convergence gap (compared to scipy.optimize.minimize). I think on the BlackJAX side we can be more explicit about detecting non-convergence.

junpenglao avatar Dec 13 '23 11:12 junpenglao

In the interim, I suggest adding some noise to the initial position: https://github.com/pymc-devs/pymc-experimental/blob/00d7a2b3cf3379e0a9420fb436667ab781e5a5e7/pymc_experimental/inference/pathfinder.py#L104, so at the very least we can run Pathfinder a couple of times.
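A minimal sketch of that workaround, assuming the initial position is a dict of arrays on the unconstrained space (the helper and variable names below are hypothetical, not the pymc-experimental API):

```python
import numpy as np

def jitter_initial_point(init_point, rng, scale=1.0):
    """Add uniform noise in (-scale, scale) to each unconstrained initial value,
    so repeated Pathfinder runs start from different positions."""
    return {
        name: value + rng.uniform(-scale, scale, size=np.shape(value))
        for name, value in init_point.items()
    }

rng = np.random.default_rng(0)
# hypothetical unconstrained initial point for the model above
base = {"phi_interval__": np.array(0.0), "kappa_log_log__": np.array(0.4)}
jittered = jitter_initial_point(base, rng)
```

Running Pathfinder several times from jittered starts at least exposes run-to-run variability that a single deterministic start hides.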

junpenglao avatar Dec 13 '23 11:12 junpenglao

You can use this to add jitter to RVs: https://github.com/pymc-devs/pymc/blob/0fd7b9e1d2208f1250b1c804bf5421013dba9023/pymc/initial_point.py#L111

ricardoV94 avatar Dec 13 '23 11:12 ricardoV94

After searching the BlackJAX Pathfinder backend for the calculations causing the poor posterior estimates and not finding anything, I decided to compare the Stan Pathfinder estimates with BlackJAX's.

The comparison between Stan and PyMC is in the two notebooks linked below (apologies for the untidy notebooks and code, but the images should provide a good enough summary):

Eight schools data: https://gist.github.com/aphc14/a32d1f81b8993b8cc57867cd4466edbb

MLB data (from above): https://gist.github.com/aphc14/9f38b2e45fd220ae4bf1eb6b967ca886

The most surprising outcome is that Stan's version of Pathfinder also provides a poor estimate of the posterior on these two datasets. When initialising PyMC + BlackJAX with jitter_rvs, the outputs are reasonably close to Stan's.

From these comparisons, is it safe to say there are no big issues with the backend calculations in BlackJAX Pathfinder?

aphc14 avatar Sep 27 '24 11:09 aphc14

Or both are wrong, but in different places 😬 Is there any other implementation we can cross-reference against?

junpenglao avatar Sep 27 '24 15:09 junpenglao

The Pathfinder paper seems to show decent results for the centered model using multi-path (Fig 19).
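For reference, the multi-path variant pools draws from several single-path runs and importance-resamples them toward the target. A rough numpy sketch of the resampling step (plain importance weights; the paper uses Pareto-smoothed weights, which this omits):

```python
import numpy as np

def importance_resample(draws, logp_target, logq_approx, num_draws, rng):
    """Resample pooled draws with weights w_i proportional to p(x_i) / q(x_i).

    draws: (N, D) draws pooled from all single-path approximations
    logp_target, logq_approx: (N,) log densities evaluated at those draws
    """
    logw = logp_target - logq_approx
    logw = logw - logw.max()                 # shift for numerical stability
    w = np.exp(logw)
    w = w / w.sum()
    idx = rng.choice(len(draws), size=num_draws, replace=True, p=w)
    return draws[idx]

# toy check: resample draws from a slightly shifted normal toward a standard normal
rng = np.random.default_rng(1)
draws = rng.normal(size=(1000, 2))
logp = -0.5 * (draws ** 2).sum(axis=1)             # toy target density
logq = -0.5 * ((draws - 0.1) ** 2).sum(axis=1)     # toy approximation density
sample = importance_resample(draws, logp, logq, 200, rng)
```

Pooling across paths is what seems to rescue the centered parameterisation in the paper's Fig 19, since no single L-BFGS trajectory has to land near the mode.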

fonnesbeck avatar Sep 27 '24 17:09 fonnesbeck

There is another package that implements Pathfinder, although it's in Julia. But since the Pathfinder paper uses Stan, shouldn't we cross-check our results against Stan? I'll code up a more streamlined comparison across other scenarios from posteriordb. I could measure performance with the scaled 1-Wasserstein metric against the same or a similar reference posterior, to see whether PyMC's results resemble Stan's. I'll probably get to this after improving our PyMC implementation.
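For equal-size 1-D empirical samples, the 1-Wasserstein distance reduces to the mean absolute difference of the sorted draws; a sketch of that per-marginal computation (the scaling convention used for the "scaled" variant is left out, as it is an assumption of the eventual comparison setup):

```python
import numpy as np

def wasserstein_1d(x, y):
    """1-Wasserstein distance between two equal-size 1-D empirical samples:
    the mean absolute difference of the sorted draws."""
    assert len(x) == len(y)
    return float(np.mean(np.abs(np.sort(x) - np.sort(y))))

# toy check against draws with a known mean shift of 0.5
rng = np.random.default_rng(2)
ref = rng.normal(0.0, 1.0, size=5000)       # reference posterior draws
approx = rng.normal(0.5, 1.0, size=5000)    # approximate posterior draws
d = wasserstein_1d(approx, ref)             # close to the 0.5 mean shift
```

Applied marginal-by-marginal against a posteriordb reference posterior, this would give a per-parameter discrepancy comparable across implementations.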

aphc14 avatar Oct 03 '24 11:10 aphc14