
initialize_model() fails with batch size of 1 for model with discrete variables

Open rafaol opened this issue 3 years ago • 1 comments

Hi,

I'm having issues with initialize_model() in the latest version of NumPyro on a model with discrete variables. The model samples random variables within a plate context whose size depends on the data. When the given data has size $N > 1$, everything works fine. However, if I pass a dataset with a single data point ($N = 1$), initialize_model fails. Apparently, the funsor API expects the computed log-probability of a sampled variable to be a scalar (shape `()`), when in this case it is actually a single-element 1D array (shape `(1,)`).
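To make the shape difference concrete, here is a minimal sketch that computes a per-data-point Bernoulli log-probability by hand. It uses plain NumPy in place of jax.numpy (the broadcasting rules are the same), the helper `bernoulli_log_prob` is hypothetical and not part of NumPyro, and it assumes observed labels for simplicity, whereas the model in this issue enumerates them.

```python
import numpy as np


def bernoulli_log_prob(X, phi, z):
    """Hypothetical helper: log p(z | X, phi) per data point, shape (N,)."""
    logits = np.tensordot(X, phi, axes=1)  # (N, 2) contracted with (2,) gives (N,)
    # Bernoulli log-probability written in terms of logits.
    return z * logits - np.logaddexp(0.0, logits)


phi = np.zeros(2)
lp_n10 = bernoulli_log_prob(np.zeros((10, 2)), phi, np.ones(10))
lp_n1 = bernoulli_log_prob(np.zeros((1, 2)), phi, np.ones(1))

print(lp_n10.shape)  # one entry per data point
print(lp_n1.shape)   # still 1D for a single data point, not a scalar
```

With $N = 1$ the result keeps its plate axis, so anything downstream that expects shape `()` after accounting for plates will reject it.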

Minimal working example:

import jax
import jax.numpy as jnp
import numpyro
from numpyro.distributions import Bernoulli, MultivariateNormal
from numpyro.infer.util import initialize_model


def model(X: jnp.ndarray, y: jnp.ndarray, z: jnp.ndarray):
    phi = numpyro.sample('phi', MultivariateNormal(jnp.zeros(2), jnp.eye(2)))

    n_data = X.shape[-2]
    with numpyro.plate('individual', n_data):
        weights = jnp.tensordot(X, phi, axes=1)
        numpyro.sample('labels', Bernoulli(probs=(1/(1 + jnp.exp(-weights)))), obs=z, infer={'enumerate': 'parallel'})
        coefficients = numpyro.sample('coefficients', MultivariateNormal(jnp.zeros(2), jnp.eye(2)))
        numpyro.sample('responses', MultivariateNormal(X * coefficients, jnp.eye(2)), obs=y)


if __name__ == '__main__':
    rng_key = jax.random.PRNGKey(0)

    # This line runs without problems:
    res = initialize_model(rng_key, model, model_kwargs={'X': jnp.zeros((10, 2)), 'z': None, 'y': None})

    # This line fails
    res1 = initialize_model(rng_key, model, model_kwargs={'X': jnp.zeros((1, 2)), 'z': None, 'y': None})

    print('Done')

Execution output:

Traceback (most recent call last):
  File "/home/rafael/Projects/trajectory-clustering/funsors_issue.py", line 26, in <module>
    res1 = initialize_model(rng_key, model, model_kwargs={'X': jnp.zeros((1, 2)), 'z': None, 'y': None})
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 654, in initialize_model
    (init_params, pe, grad), is_valid = find_valid_initial_params(
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 395, in find_valid_initial_params
    (init_params, pe, z_grad), is_valid = _find_valid_params(
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 381, in _find_valid_params
    _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 366, in body_fn
    pe, z_grad = value_and_grad(potential_fn)(params)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/api.py", line 1063, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/api.py", line 2558, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/interpreters/ad.py", line 133, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/interpreters/ad.py", line 122, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/profiler.py", line 312, in wrapper
    return func(*args, **kwargs)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 621, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 248, in potential_energy
    log_joint, model_trace = log_density_(
  File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 270, in log_density
    result, model_trace, _ = _enum_log_density(
  File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 181, in _enum_log_density
    log_prob_factor = funsor.to_funsor(
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/functools.py", line 888, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/funsor/tensor.py", line 491, in tensor_to_funsor
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Invalid shape: expected (), actual (1,)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 248, in potential_energy
    log_joint, model_trace = log_density_(
  File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 270, in log_density
    result, model_trace, _ = _enum_log_density(
  File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 181, in _enum_log_density
    log_prob_factor = funsor.to_funsor(
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/functools.py", line 888, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/funsor/tensor.py", line 491, in tensor_to_funsor
    raise ValueError(
ValueError: Invalid shape: expected (), actual (1,)

rafaol, Jul 11 '22 03:07

Thanks @rafaol! I guess that somewhere we mixed up a size-1 plate with the singleton dimensions created by promotion and broadcasting. We'll take a closer look at this.

fehiepsi, Jul 17 '22 21:07
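The kind of mix-up suggested in the reply above can be sketched as follows. This is an illustration only, not NumPyro's or funsor's actual code: both helpers and the `(dim, size)` frame tuples are hypothetical. The point is that a heuristic which infers plate axes from array shape alone cannot tell a size-1 plate from a broadcast singleton, whereas explicit plate metadata can.

```python
import numpy as np


def guess_plate_dims(shape):
    """Hypothetical heuristic: treat every axis of size > 1 as a plate axis.

    Returns negative axis indices, counted from the right.
    """
    return tuple(i - len(shape) for i, s in enumerate(shape) if s > 1)


def plate_dims_from_frames(frames):
    """Hypothetical alternative: read plate axes from explicit (dim, size) records."""
    return tuple(dim for dim, _size in frames)


lp_n10 = np.zeros((10,))  # log-prob under a plate of size 10
lp_n1 = np.zeros((1,))    # log-prob under a plate of size 1

# Shape-based guessing finds the plate axis only when N > 1.
print(guess_plate_dims(lp_n10.shape))
print(guess_plate_dims(lp_n1.shape))

# Explicit metadata keeps the axis even when its size is 1.
print(plate_dims_from_frames([(-1, 10)]))
print(plate_dims_from_frames([(-1, 1)]))
```

Under the shape-based heuristic the $N = 1$ array appears to have no plate axes at all, so its leftover shape `(1,)` would be treated as event or output shape, which is consistent with the "expected (), actual (1,)" error reported above.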