ENH: Support forward sampling via JAX

Open jessegrabowski opened this issue 8 months ago • 8 comments

Before

import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model() as m:
    ...

with freeze_dims_and_data(m):
    idata_prior = pm.sample_prior_predictive(compile_kwargs={"mode":"JAX"})

with m:
    idata = pm.sample()

with freeze_dims_and_data(m):
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, compile_kwargs={"mode": "JAX"})

After

import pymc as pm

with pm.Model() as m:
    ...
    idata_prior = pm.sample_prior_predictive(mode="JAX")
    idata = pm.sample()
    idata = pm.sample_posterior_predictive(idata, mode="JAX", extend_inferencedata=True)

Context for the issue:

For models involving scan or multivariate-normal distributions, you get big speedups by passing compile_kwargs={'mode':'JAX'} to pm.sample_prior_predictive and pm.sample_posterior_predictive. This has already proven useful in statespace modeling (in pymc-experimental, https://github.com/pymc-devs/pymc-experimental/pull/346) and instrumental variable modeling (in CausalPy, https://github.com/pymc-labs/CausalPy/pull/345). In both cases the JAX backend gave significant speedups, which makes this a highly desirable feature.

This was technically never a supported feature, but it could be made to work by deliberately specifying the whole model as static (e.g. using pm.ConstantData and avoiding mutable_kwargs). After #7047 this is no longer possible. The workaround is to use freeze_dims_and_data, but this is somewhat cumbersome, especially for prior predictive sampling, where a typical workflow calls pm.sample_prior_predictive inside the model block at construction time. I have also come across cases where freeze_dims_and_data fails. A trivial example is predictive modeling using pm.Flat dummies -- this adds non-None entries to model.rvs_to_initial_values, causing model_to_fgraph to fail.

My proposal is to simply add a "freezing" step to compile_forward_sampling_function. This would remove the need for users to know about the freeze_dims_and_data helper function, allow JAX forward sampling without breaking out of a single model context, and also support any future backend that requires all shape information to be known.
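To make the idea concrete, here is a rough user-level sketch of what "freeze, then forward sample" could look like. It is not the proposed change itself, which would live inside compile_forward_sampling_function. The name forward_sample_frozen and its signature are hypothetical, and the sampler and freezing function are passed in as arguments so the sketch does not depend on any particular PyMC version.

```python
def forward_sample_frozen(model, sampler, freeze, compile_kwargs=None, **sampler_kwargs):
    """Hypothetical wrapper: freeze the model only when the backend needs static shapes.

    model   -- a context-manager model object (e.g. a pm.Model)
    sampler -- a forward sampling function that accepts compile_kwargs
    freeze  -- a function returning a frozen copy of the model
    """
    # Copy so the caller's dict is never mutated.
    compile_kwargs = dict(compile_kwargs or {})

    # Assumption: only the JAX mode requires fully static shape information.
    needs_static_shapes = compile_kwargs.get("mode") == "JAX"
    target = freeze(model) if needs_static_shapes else model

    # Run the sampler inside the (possibly frozen) model context.
    with target:
        return sampler(compile_kwargs=compile_kwargs, **sampler_kwargs)
```

With PyMC, a call might look like forward_sample_frozen(m, pm.sample_prior_predictive, freeze_dims_and_data, compile_kwargs={"mode": "JAX"}), assuming the sampler accepts compile_kwargs as in the "Before" snippet above.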

I would also propose to officially support and expose alternative forward sampling backends by promoting backend= or mode= to a kwarg in pm.sample_*_predictive, rather than hiding it inside compile_kwargs.
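One way the promoted keyword could coexist with the existing compile_kwargs is sketched below. The helper name resolve_compile_kwargs is hypothetical and not part of PyMC. It only illustrates folding a top-level mode= into compile_kwargs and rejecting conflicting values.

```python
def resolve_compile_kwargs(mode=None, compile_kwargs=None):
    """Hypothetical helper: fold a top-level `mode` argument into compile_kwargs."""
    # Copy so the caller's dict is never mutated.
    merged = dict(compile_kwargs or {})
    if mode is None:
        return merged

    # Refuse silently overriding a different mode given via compile_kwargs.
    existing = merged.get("mode")
    if existing is not None and existing != mode:
        raise ValueError(
            f"mode={mode!r} conflicts with compile_kwargs['mode']={existing!r}"
        )

    merged["mode"] = mode
    return merged
```

Raising on a conflict rather than picking a winner is a design choice made for this sketch. Letting the explicit mode= take precedence would be an equally reasonable alternative.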

jessegrabowski, Jun 04 '24 11:06