numpyro
numpyro copied to clipboard
initialize_model() fails with batch size of 1 for model with discrete variables
Hi,
I'm having issues using `initialize_model()` from the latest version of NumPyro on a model with discrete variables. The model samples random variables within a plate context whose size depends on the data. When the given data is of size $N > 1$, everything works fine. However, if I pass a dataset with a single data point ($N = 1$), the execution of `initialize_model()` fails. Apparently, the funsor API expects the computed log-probability of a sampled variable to be a scalar, when in fact it is a single-element 1D array in this case.
Minimal working example:
import jax
import jax.numpy as jnp
import numpyro
from numpyro.distributions import Bernoulli, MultivariateNormal
from numpyro.infer.util import initialize_model
def model(X: jnp.ndarray, y: jnp.ndarray, z: jnp.ndarray):
    """Minimal reproducer: a plated model containing an enumerated Bernoulli site.

    NOTE(review): the indentation of this snippet was lost when the page was
    extracted. The plate body is presumed to extend to the end of the
    function -- confirm against the original report.

    Args:
        X: Array whose second-to-last axis length is used as the plate size.
        y: Value passed as ``obs`` to the 'responses' sample site.
        z: Value passed as ``obs`` to the 'labels' sample site.
    """
    # Latent 'phi' drawn from a zero-mean 2-D MultivariateNormal built with jnp.eye(2).
    phi = numpyro.sample('phi', MultivariateNormal(jnp.zeros(2), jnp.eye(2)))
    # The plate size is data-dependent: it is read from X's second-to-last axis.
    n_data = X.shape[-2]
    with numpyro.plate('individual', n_data):
        # axes=1 contracts the last axis of X with the first axis of phi.
        weights = jnp.tensordot(X, phi, axes=1)
        # Bernoulli probabilities are the logistic transform of the weights;
        # the site is flagged for parallel enumeration via its `infer` dict.
        numpyro.sample('labels', Bernoulli(probs=(1/(1 + jnp.exp(-weights)))), obs=z, infer={'enumerate': 'parallel'})
        # Latent 'coefficients' uses the same distribution arguments as 'phi'.
        coefficients = numpyro.sample('coefficients', MultivariateNormal(jnp.zeros(2), jnp.eye(2)))
        # The mean of 'responses' is the elementwise product X * coefficients.
        numpyro.sample('responses', MultivariateNormal(X * coefficients, jnp.eye(2)), obs=y)
if __name__ == '__main__':
    # The same PRNG key is reused for both initialize_model calls below.
    rng_key = jax.random.PRNGKey(0)
    # This line runs without problems (X has 10 rows, so the plate size is 10):
    res = initialize_model(rng_key, model, model_kwargs={'X': jnp.zeros((10, 2)), 'z': None, 'y': None})
    # This line fails (X has a single row, so the plate size is 1) with
    # "ValueError: Invalid shape: expected (), actual (1,)":
    res1 = initialize_model(rng_key, model, model_kwargs={'X': jnp.zeros((1, 2)), 'z': None, 'y': None})
    print('Done')
Execution output:
Traceback (most recent call last):
File "/home/rafael/Projects/trajectory-clustering/funsors_issue.py", line 26, in <module>
res1 = initialize_model(rng_key, model, model_kwargs={'X': jnp.zeros((1, 2)), 'z': None, 'y': None})
File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 654, in initialize_model
(init_params, pe, grad), is_valid = find_valid_initial_params(
File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 395, in find_valid_initial_params
(init_params, pe, z_grad), is_valid = _find_valid_params(
File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 381, in _find_valid_params
_, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 366, in body_fn
pe, z_grad = value_and_grad(potential_fn)(params)
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/api.py", line 1063, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/api.py", line 2558, in _vjp
out_primal, out_vjp = ad.vjp(
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/interpreters/ad.py", line 133, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/interpreters/ad.py", line 122, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/profiler.py", line 312, in wrapper
return func(*args, **kwargs)
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 621, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 248, in potential_energy
log_joint, model_trace = log_density_(
File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 270, in log_density
result, model_trace, _ = _enum_log_density(
File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 181, in _enum_log_density
log_prob_factor = funsor.to_funsor(
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/functools.py", line 888, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/funsor/tensor.py", line 491, in tensor_to_funsor
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Invalid shape: expected (), actual (1,)
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 248, in potential_energy
log_joint, model_trace = log_density_(
File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 270, in log_density
result, model_trace, _ = _enum_log_density(
File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 181, in _enum_log_density
log_prob_factor = funsor.to_funsor(
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/functools.py", line 888, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/funsor/tensor.py", line 491, in tensor_to_funsor
raise ValueError(
ValueError: Invalid shape: expected (), actual (1,)
Thanks @rafaol! I guess we mixed up a size-1 plate with the singleton dimensions created by promotion and broadcasting somewhere. Let us take a closer look into this.