numpyro

AutoContinuous/funsor bug?

Open amifalk opened this issue 1 year ago • 3 comments

Here's a reproducible example adapted almost directly from the Gaussian Mixture Model tutorial. The AutoContinuous guide (here, AutoDiagonalNormal) appears to be what fails.

import jax.numpy as jnp
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceEnum_ELBO, autoguide
from numpyro.handlers import block, seed

# Toy 1-D observations forming two well-separated groups.
data = jnp.asarray([0.0, 1.0, 10.0, 11.0, 12.0])

# Number of mixture components; fixed rather than inferred.
K = 2

def model(data):
    """K-component 1-D Gaussian mixture with a discrete per-point assignment.

    Global sites: ``weights`` (Dirichlet mixture weights), ``scale`` (shared
    LogNormal observation scale) and ``locs`` (one Normal location per
    component, inside the ``components`` plate). Local sites, inside the
    ``data`` plate: ``assignment`` (Categorical, marked for parallel
    enumeration through ``infer``) and the observed ``obs``.

    Args:
        data: observed values; its length sizes the ``data`` plate.
    """
    # --- global latent variables ---
    mix_weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
    obs_scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))

    with numpyro.plate("components", K):
        component_locs = numpyro.sample("locs", dist.Normal(0.0, 10.0))

    # --- local (per-observation) variables ---
    with numpyro.plate("data", len(data)):
        z = numpyro.sample(
            "assignment",
            dist.Categorical(mix_weights),
            infer={"enumerate": "parallel"},
        )
        numpyro.sample("obs", dist.Normal(component_locs[z], obs_scale), obs=data)
        
# Works: AutoNormal trains on the model with the enumerated
# "assignment" site hidden from the guide.
guide = autoguide.AutoNormal(
    block(seed(model, rng_seed=0), hide=["assignment"])
)
svi = SVI(model, guide, optim=numpyro.optim.Adam(0.003), loss=TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data)

# Fails with KeyError: 'components': AutoDiagonalNormal is an AutoContinuous
# guide, and AutoContinuous guides do not support enumerated models.
guide = autoguide.AutoDiagonalNormal(
    block(seed(model, rng_seed=0), hide=["assignment"])
)
svi = SVI(model, guide, optim=numpyro.optim.Adam(0.003), loss=TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data)

Here's the associated stack trace.

File ".../jax/lib/python3.11/site-packages/numpyro/contrib/funsor/enum_messenger.py", lines 426-430:
    426     if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
    427         for name in msg["args"][0].inputs:
    428             self._saved_globals += (
--> 429                 (name, _DIM_STACK.global_frame.name_to_dim[name]),
    430             )

KeyError: 'components'

If I replace the `components` plate with `locs = numpyro.sample("locs", dist.Normal(0.0, 10.0).expand((K,)).to_event(1))`, I get the same KeyError, but on the `data` plate instead.

— amifalk, Jan 05 '24 02:01

Hi @amifalk, AutoContinuous does not work with enumerated models. We should raise a better error message for this.

— fehiepsi, Jan 05 '24 16:01

Is there a way to add this functionality (even if only for a subset of models), or is it a limitation of numpyro?

— amifalk, Jan 05 '24 18:01

Yes, it's a limitation of the black-box (AutoContinuous) guides. It would be much easier to write a custom guide for your model.

— fehiepsi, Jan 05 '24 21:01