
Mixture of tuple-valued distributions for TFP on JAX raises a TypeError

Open pawel-czyz opened this issue 1 year ago • 0 comments

Short description

The Mixture distribution does not seem to be compatible with tuple-valued components. Constructing such a mixture raises the following error:

TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'

Such distributions appear, e.g., in:

I'm not sure where the issue lies, but with some guidance I'd be more than happy to work on a fix! 🙂

Code example

Consider three structured distributions, each producing samples of the form (int, float):

from tensorflow_probability.substrates import jax as tfp

import jax
import jax.numpy as jnp

tfd = tfp.distributions

dist1 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.1), lambda x: tfd.Normal(0.0 + x, 0.2)])
dist2 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.3), lambda x: tfd.Normal(1.0 + x, 0.2)])
dist3 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.5), lambda x: tfd.Normal(2.0 + x, 0.2)])

probs = jnp.asarray([0.05, 0.1, 0.85])

# Constructing the mixture raises the TypeError shown below.
mixture = tfd.Mixture(
    cat=tfd.Categorical(probs=probs),
    components=[dist1, dist2, dist3],
)
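For reference, the mixture this snippet tries to build is a distribution over pairs (b, y) whose density is the weighted sum of the component densities: log p(b, y) = logsumexp_k [log π_k + log Bernoulli(b; p_k) + log Normal(y; μ_k + b, 0.2)], with π = (0.05, 0.1, 0.85), p = (0.1, 0.3, 0.5) and μ = (0, 1, 2). Below is a minimal pure-Python sketch of that density and of ancestral sampling (pick a component, then sample from it). It assumes exactly the parameters of the example and uses only the standard library rather than TFP.

```python
import math
import random

# Parameters copied from the example above (assumed fixed here):
# component k draws b ~ Bernoulli(p_k), then y ~ Normal(mu_k + b, 0.2).
MIX_PROBS = [0.05, 0.1, 0.85]
BERNOULLI_PROBS = [0.1, 0.3, 0.5]
NORMAL_LOCS = [0.0, 1.0, 2.0]
NORMAL_SCALE = 0.2


def bernoulli_log_prob(b, p):
    """log P(B = b) for b in {0, 1}."""
    return math.log(p) if b == 1 else math.log1p(-p)


def normal_log_prob(y, loc, scale):
    """Log density of Normal(loc, scale) evaluated at y."""
    z = (y - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def mixture_log_prob(b, y):
    """log p(b, y) = logsumexp_k [log pi_k + log p_k(b) + log p_k(y | b)]."""
    terms = [
        math.log(pi)
        + bernoulli_log_prob(b, p)
        + normal_log_prob(y, loc + b, NORMAL_SCALE)
        for pi, p, loc in zip(MIX_PROBS, BERNOULLI_PROBS, NORMAL_LOCS)
    ]
    return logsumexp(terms)


def mixture_sample(rng):
    """Ancestral sampling: pick a component k, then draw (b, y) from it."""
    k = rng.choices(range(len(MIX_PROBS)), weights=MIX_PROBS)[0]
    b = 1 if rng.random() < BERNOULLI_PROBS[k] else 0
    y = rng.gauss(NORMAL_LOCS[k] + b, NORMAL_SCALE)
    return b, y


if __name__ == "__main__":
    rng = random.Random(0)
    print(mixture_sample(rng))
    print(mixture_log_prob(1, 3.0))
```

With TFP, the same expression could presumably be evaluated by computing d.log_prob(x) for each JointDistributionSequential component, adding jnp.log(probs) and reducing with jax.nn.logsumexp. This is an untested suggestion, not a documented TFP workaround.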

Version

I used Python 3.11 and TFP 0.20.1. I then reproduced the behavior using the development version, 0.25.0.dev20240628.

Full error message

TypeError                                 Traceback (most recent call last)
Cell In[5], line 14
     10 dist3 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.5), lambda x: tfd.Normal(2.0 + x, 0.2)])
     12 probs = jnp.asarray([0.05, 0.1, 0.85])
---> 14 mixture = tfd.Mixture(
     15     cat=tfd.Categorical(probs=probs),
     16     components=[dist1, dist2, dist3],
     17 )

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mixture.py:155, in _Mixture.__init__(self, cat, components, validate_args, allow_nan_stats, name)
    153 static_batch_shape = cat.batch_shape
    154 for di, d in enumerate(components):
--> 155   if not tensorshape_util.is_compatible_with(static_batch_shape,
    156                                              d.batch_shape):
    157     raise ValueError(
    158         'components[{}] batch shape must be compatible with cat '
    159         'shape and other component batch shapes ({} vs {})'.format(
    160             di, static_batch_shape, d.batch_shape))
    161   if not tensorshape_util.is_compatible_with(static_event_shape,
    162                                              d.event_shape):

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/tensorshape_util.py:211, in is_compatible_with(x, other)
    199 def is_compatible_with(x, other):
    200   """Returns `True` iff `x` is compatible with `other`.
    201 
    202   For more details, see `help(tf.TensorShape.is_compatible_with)`.
   (...)
    209     is_compatible: `bool` indicating of the shapes are compatible.
    210   """
--> 211   return tf.TensorShape(x).is_compatible_with(other)

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:1437, in TensorShape.is_compatible_with(self, other)
   1399 def is_compatible_with(self, other):
   1400   """Returns True iff `self` is compatible with `other`.
   1401 
   1402   Two possibly-partially-defined shapes are compatible if there
   (...)
   1435 
   1436   """
-> 1437   other = as_shape(other)
   1438   if self.dims is not None and other.dims is not None:
   1439     if self.rank != other.rank:

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:1624, in as_shape(shape)
   1622   return shape
   1623 else:
-> 1624   return TensorShape(shape)

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:905, in TensorShape.__init__(self, dims)
    896 """Creates a new TensorShape with the given dimensions.
    897 
    898 Args:
   (...)
    902   TypeError: If dims cannot be converted to a list of dimensions.
    903 """
    904 if isinstance(dims, (tuple, list)):  # Most common case.
--> 905   self._dims = tuple(as_dimension(d).value for d in dims)
    906 elif dims is None:
    907   self._dims = None

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:905, in <genexpr>(.0)
    896 """Creates a new TensorShape with the given dimensions.
    897 
    898 Args:
   (...)
    902   TypeError: If dims cannot be converted to a list of dimensions.
    903 """
    904 if isinstance(dims, (tuple, list)):  # Most common case.
--> 905   self._dims = tuple(as_dimension(d).value for d in dims)
    906 elif dims is None:
    907   self._dims = None

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:819, in as_dimension(value)
    817   return value
    818 else:
--> 819   return Dimension(value)

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:295, in Dimension.__init__(self, value)
    293   self._value = int(value.__index__())
    294 except AttributeError:
--> 295   raise TypeError(
    296       "Dimension value must be integer or None or have "
    297       "an __index__ method, got value '{0!r}' with type '{1!r}'".format(
    298           value, type(value))) from None
    299 if self._value < 0:
    300   raise ValueError("Dimension %d must be >= 0" % self._value)

TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'
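Reading the traceback from the bottom up suggests where the error comes from. Mixture.__init__ checks each component with tensorshape_util.is_compatible_with(static_batch_shape, d.batch_shape), which ends up calling TensorShape(d.batch_shape). For a JointDistributionSequential, batch_shape appears to be a list-like structure holding one TensorShape per sub-distribution rather than a single TensorShape. The TensorShape constructor then tries to turn each inner TensorShape([]) into a Dimension, which fails because TensorShape has no __index__ method. The event_shape check that follows in mixture.py would presumably hit the same problem. The sketch below mimics this with hypothetical, heavily simplified stand-ins for Dimension and TensorShape (not TFP's actual classes). It also adds a hypothetical structure-aware helper, structured_is_compatible_with, that compares the categorical's batch shape against every leaf. It only illustrates the failing shape check. Whether Mixture's sampling and log_prob would then work for structured components is a separate question and is not addressed here.

```python
# Hypothetical, heavily simplified stand-ins for the Dimension and TensorShape
# classes that appear in the traceback. They only model the behaviour visible
# there and are NOT TFP's actual implementation.


class Dimension:
    def __init__(self, value):
        if value is None:
            self.value = None
            return
        try:
            self.value = int(value.__index__())
        except AttributeError:
            raise TypeError(
                "Dimension value must be integer or None or have an __index__ "
                f"method, got value {value!r} with type {type(value)!r}"
            ) from None


class TensorShape:
    def __init__(self, dims):
        if isinstance(dims, TensorShape):
            self.dims = dims.dims
        elif isinstance(dims, (tuple, list)):
            # Every entry must convert to a Dimension; a nested TensorShape cannot.
            self.dims = tuple(Dimension(d).value for d in dims)
        elif dims is None:
            self.dims = None  # unknown rank
        else:
            raise TypeError(f"cannot convert {dims!r} to a shape")

    def __repr__(self):
        if self.dims is None:
            return "TensorShape(None)"
        return f"TensorShape({list(self.dims)})"

    def is_compatible_with(self, other):
        other = TensorShape(other)  # mirrors `as_shape(other)` in the traceback
        if self.dims is None or other.dims is None:
            return True
        if len(self.dims) != len(other.dims):
            return False
        return all(a is None or b is None or a == b
                   for a, b in zip(self.dims, other.dims))


def _is_shape_nest(obj):
    """True if obj is a TensorShape or a non-empty list/tuple nest of them."""
    if isinstance(obj, TensorShape):
        return True
    return (isinstance(obj, (list, tuple)) and len(obj) > 0
            and all(_is_shape_nest(o) for o in obj))


def structured_is_compatible_with(x, other):
    """Hypothetical helper: if `other` is a nest of TensorShapes, require `x`
    to be compatible with every leaf; otherwise fall back to a plain check."""
    if isinstance(other, (list, tuple)) and _is_shape_nest(other):
        return all(structured_is_compatible_with(x, o) for o in other)
    return TensorShape(x).is_compatible_with(other)


if __name__ == "__main__":
    cat_batch_shape = TensorShape([])
    # Assumed value of the JointDistributionSequential's batch_shape in the
    # example, inferred from the error message: one TensorShape per component.
    jd_batch_shape = [TensorShape([]), TensorShape([])]
    try:
        cat_batch_shape.is_compatible_with(jd_batch_shape)
    except TypeError as err:
        print("TypeError:", err)
    print(structured_is_compatible_with(cat_batch_shape, jd_batch_shape))  # True
```

Flattening the structure and checking each leaf keeps the original per-shape semantics intact. Only the entry point changes, which is why the helper falls back to the plain check for ordinary shapes.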

pawel-czyz, Jun 28 '24 21:06