`AutoNormal`, `AutoDelta`, and `AutoGuideList` do not support subsamples of variable size.
`AutoNormal`, `AutoDelta`, and `AutoGuideList` raise an exception in SVI when the subsample size varies across `log_density` evaluations. The example below reproduces the issue (run on master).
import numpyro
from jax import numpy as jnp


def model(n, x=None, subsample_size=None):
    mu = numpyro.sample("mu", numpyro.distributions.Normal())
    with numpyro.plate("n", n, subsample_size=subsample_size):
        numpyro.sample("x", numpyro.distributions.Normal(mu, 1), obs=x)


def demo(guide_cls):
    n = 10
    x_obs = jnp.zeros(n)
    guide = guide_cls(model)

    with numpyro.handlers.seed(rng_seed=0):
        # Initialize the guide with the full dataset, get a trace, and replay against
        # the model.
        guide(n, x_obs)
        guide_trace = numpyro.handlers.trace(guide).get_trace()
        replayed = numpyro.handlers.replay(model, guide_trace)

        print("evaluate log density for full data")
        numpyro.infer.util.log_density(replayed, (n, x_obs), {}, {})

        print("evaluate log density for subsampled data")
        numpyro.infer.util.log_density(replayed, (n, x_obs[:3], 3), {}, {})
        print("done")


# This works just fine.
demo(numpyro.infer.autoguide.AutoDiagonalNormal)
# This raises an error (see traceback below).
demo(numpyro.infer.autoguide.AutoNormal)
The traceback for the failed call is as follows.
evaluate log density for full data
evaluate log density for subsampled data
.../numpyro/playground/test.py:7: UserWarning: subsample_size does not match len(subsample), 3 vs 10. Did you accidentally use different subsample_size in the model and guide?
  with numpyro.plate("n", n, subsample_size=subsample_size):
Traceback (most recent call last):
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 151, in broadcast_shapes
    return _broadcast_shapes_cached(*shapes)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/util.py", line 287, in wrapper
    return cached(config.config._trace_context(), *args, **kwargs)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/util.py", line 280, in cached
    return f(*args, **kwargs)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 157, in _broadcast_shapes_cached
    return _broadcast_shapes_uncached(*shapes)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 173, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (10,)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".../numpyro/numpyro/infer/util.py", line 80, in log_density
    broadcast_shapes(guide_shape, model_shape)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 153, in broadcast_shapes
    return _broadcast_shapes_uncached(*shapes)
  File ".../numpyro/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 173, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (10,)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".../numpyro/playground/test.py", line 32, in <module>
    demo(numpyro.infer.autoguide.AutoNormal)
  File ".../numpyro/playground/test.py", line 27, in demo
    numpyro.infer.util.log_density(replayed, (n, x_obs[:3], 3), {}, {})
  File ".../numpyro/numpyro/infer/util.py", line 82, in log_density
    raise ValueError(
ValueError: Model and guide shapes disagree at site: 'x': (10,) vs (3,)
I think the issue is that these guides use `_create_plates`, which in turn uses prototype traces to determine the subsample size.
https://github.com/pyro-ppl/numpyro/blob/aec6bd58b4cf0c2d81b96e62d4d0cf7af3744885/numpyro/infer/autoguide.py#L108-L113
The prototype traces are created only on the first invocation, so the guide keeps expecting the original subsample size when a later call uses a different mini-batch size. Guides inheriting from `AutoContinuous` do not call `_create_plates` and do not use plates in their `__call__` method. I couldn't quite figure out why some guides do and some don't.
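The caching behaviour described above can be illustrated with a toy class. This is a minimal sketch and not numpyro code: `PrototypeCachingGuide` and its `plate_size` method are hypothetical names, and the logic only mimics the idea of fixing the plate size on the first call unless a `create_plates`-style callback is supplied.

```python
class PrototypeCachingGuide:
    """Hypothetical stand-in for a guide that remembers plate sizes from its first call."""

    def __init__(self, create_plates=None):
        self.create_plates = create_plates
        self.prototype_size = None  # filled in on the first call only

    def plate_size(self, n, subsample_size=None):
        requested = n if subsample_size is None else subsample_size
        if self.prototype_size is None:
            # Analogous to building the prototype trace once.
            self.prototype_size = requested
        if self.create_plates is not None:
            # A callback recomputes the size from the current arguments.
            return self.create_plates(n, subsample_size)
        # Without a callback, the size from the first call is reused.
        return self.prototype_size


cached = PrototypeCachingGuide()
first = cached.plate_size(10)                    # 10, taken from the first call
stale = cached.plate_size(10, subsample_size=3)  # still 10, the cached size

fresh = PrototypeCachingGuide(create_plates=lambda n, s=None: n if s is None else s)
fresh.plate_size(10)
updated = fresh.plate_size(10, subsample_size=3)  # 3, follows the current call
```

The stale value of 10 against a mini-batch of 3 corresponds to the `(10,) vs (3,)` mismatch in the traceback.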
This is a good point. I guess a better check is to make sure that there are no latent variables under the subsample plates. When that is the case, there is no need to specify the create_plates argument.
@tillahoffmann Sorry for the misleading last comment. For subsampling, the usage is:
create_plates = lambda n, x, subsample_size=None: numpyro.plate("n", n, subsample_size=subsample_size)
AutoNormal(..., create_plates=create_plates)