ENH: Support forward sampling via JAX
Before
    import pymc as pm
    from pymc.model.transform.optimization import freeze_dims_and_data

    with pm.Model() as m:
        ...

    with freeze_dims_and_data(m):
        idata_prior = pm.sample_prior_predictive(compile_kwargs={"mode": "JAX"})

    with m:
        idata = pm.sample()

    with freeze_dims_and_data(m):
        idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, compile_kwargs={"mode": "JAX"})
After
    import pymc as pm

    with pm.Model() as m:
        ...
        idata_prior = pm.sample_prior_predictive(mode="JAX")
        idata = pm.sample()
        idata = pm.sample_posterior_predictive(idata, mode="JAX", extend_inferencedata=True)
Context for the issue:
For models involving scan or multivariate-normal distributions, you get big speedups by passing compile_kwargs={'mode':'JAX'} to pm.sample_prior_predictive and pm.sample_posterior_predictive. This has already proven useful in statespace modeling (in pymc-experimental, https://github.com/pymc-devs/pymc-experimental/pull/346) and instrumental variable modeling (in CausalPy, https://github.com/pymc-labs/CausalPy/pull/345). In both cases the JAX backend offers significant speedups, which makes this a highly desirable feature.
This was technically never a supported feature, but it could be made to work by consciously specifying the whole model to be static (e.g. using pm.ConstantData and avoiding mutable_kwargs). After #7047 this is no longer possible. The work-around is to use freeze_dims_and_data, but this is somewhat cumbersome, especially with prior predictive sampling, where a typical workflow calls pm.sample_prior_predictive in the model block at construction time. I have also run into cases where freeze_dims_and_data fails. A trivial example is predictive modeling using pm.Flat dummies: these add non-None entries to model.rvs_to_initial_values, causing model_to_fgraph to fail.
My proposal would be to simply add a "freezing" step to compile_forward_sampling_function. This would alleviate the need for users to be aware of the freeze_dims_and_data helper function, allow JAX forward sampling without breaking out of a single model context, and also support any future backend that requires all shape information to be known.
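Until something like this exists inside PyMC, the freeze-then-sample pattern can be wrapped at the user level. The following is only a rough sketch: forward_sample_frozen is a hypothetical helper name, not a PyMC function, and it takes the sampler and the freezing function as arguments so that the pattern itself is visible. It does not address the pm.Flat failure described above.

```python
def forward_sample_frozen(sampler, model, freeze, *args, mode="JAX", **kwargs):
    """Hypothetical helper: freeze `model`, then forward-sample in the frozen copy.

    sampler: a forward sampler such as pm.sample_prior_predictive or
             pm.sample_posterior_predictive (must accept `compile_kwargs`).
    freeze:  a function such as freeze_dims_and_data that returns a new model
             usable as a context manager.
    """
    frozen = freeze(model)  # new model with static shapes; `model` is left as-is
    with frozen:
        # The sampler picks up `frozen` from the active model context.
        return sampler(*args, compile_kwargs={"mode": mode}, **kwargs)


# Intended usage (not executed here; assumes pymc and jax are installed):
#   import pymc as pm
#   from pymc.model.transform.optimization import freeze_dims_and_data
#   idata_prior = forward_sample_frozen(pm.sample_prior_predictive, m, freeze_dims_and_data)
#   idata = forward_sample_frozen(
#       pm.sample_posterior_predictive, m, freeze_dims_and_data,
#       idata, extend_inferencedata=True,
#   )
```

Doing the same thing inside compile_forward_sampling_function, as proposed, would make a wrapper like this unnecessary.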
I would also propose to officially support and expose alternative forward sampling backends by promoting backend= or mode= to a kwarg in pm.sample_*_predictive, rather than hiding it inside compile_kwargs.
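One way the promoted keyword could coexist with the existing compile_kwargs argument is to fold it in before compilation. This is a minimal sketch of that idea, not PyMC's implementation: resolve_compile_kwargs is a hypothetical name, and raising on a conflict between the two is an assumption about the desired behaviour.

```python
def resolve_compile_kwargs(mode=None, compile_kwargs=None):
    """Hypothetical: merge a top-level `mode` argument into `compile_kwargs`."""
    merged = dict(compile_kwargs or {})  # copy so the caller's dict is not mutated
    if mode is not None:
        existing = merged.get("mode")
        if existing is not None and existing != mode:
            # Assumed policy: refuse ambiguous input rather than pick a winner.
            raise ValueError(
                f"mode={mode!r} conflicts with compile_kwargs['mode']={existing!r}"
            )
        merged["mode"] = mode
    return merged
```

With this, pm.sample_prior_predictive(mode="JAX") and pm.sample_prior_predictive(compile_kwargs={"mode": "JAX"}) would resolve to the same compile settings, so existing code that uses compile_kwargs keeps working.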