numpyro
AutoContinuous/funsor bug?
Here's a reproducible example taken nearly verbatim from the Gaussian Mixture Model tutorial. The failure seems to come from the AutoContinuous guides.
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceEnum_ELBO, autoguide
from numpyro.handlers import block, seed

data = jnp.array([0.0, 1.0, 10.0, 11.0, 12.0])
K = 2  # Fixed number of components.

def model(data):
    # Global variables.
    weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
    scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))
    with numpyro.plate("components", K):
        locs = numpyro.sample("locs", dist.Normal(0.0, 10.0))
    with numpyro.plate("data", len(data)):
        # Local variables.
        assignment = numpyro.sample("assignment", dist.Categorical(weights),
                                    infer={"enumerate": "parallel"})
        numpyro.sample("obs", dist.Normal(locs[assignment], scale), obs=data)

# AutoNormal: this works
guide = autoguide.AutoNormal(block(seed(model, rng_seed=0), hide=['assignment']))
svi = SVI(model, guide, numpyro.optim.Adam(0.003), TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data)

# AutoDiagonalNormal (an AutoContinuous guide): this fails
guide = autoguide.AutoDiagonalNormal(block(seed(model, rng_seed=0), hide=['assignment']))
svi = SVI(model, guide, numpyro.optim.Adam(0.003), TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data)
Here's the associated stack trace.
.../jax/lib/python3.11/site-packages/numpyro/contrib/funsor/enum_messenger.py
    426 if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
    427     for name in msg["args"][0].inputs:
    428         self._saved_globals += (
--> 429             (name, _DIM_STACK.global_frame.name_to_dim[name]),
    430         )
KeyError: 'components'
If I replace the components plate with locs = numpyro.sample("locs", dist.Normal(0.0, 10.0).expand((K,)).to_event(1)), I get the KeyError on the 'data' plate instead.
Hi @amifalk, AutoContinuous does not work with enumerated models. We should raise a better error message for this.
Is there a way to add this functionality (even if only for a subset of models), or is it a limitation of numpyro?
Yes, it's a limitation of the black-box (AutoContinuous) guides. It would be much easier to write custom guides for your models.