Mixture of tuple-valued distributions for TFP on JAX raises a TypeError
Short description
The `Mixture` distribution does not seem to be compatible with tuple-valued component distributions; constructing such a mixture raises:
TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'
Such distributions arise, for example:
- as a mixture of several models defined by `JointDistributionSequential`;
- as a mixture of distributions defined using the `Split` bijector.
I'm not sure where the issue lies, but with some guidance I'd be more than happy to work on a fix! 🙂
Code example
Consider three structured distributions, each of which produces samples of the form (int, float):
```python
from tensorflow_probability.substrates import jax as tfp
import jax
import jax.numpy as jnp

tfd = tfp.distributions

# Three two-part models: a Bernoulli draw x, then a Normal centred at (offset + x).
dist1 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.1), lambda x: tfd.Normal(0.0 + x, 0.2)])
dist2 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.3), lambda x: tfd.Normal(1.0 + x, 0.2)])
dist3 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.5), lambda x: tfd.Normal(2.0 + x, 0.2)])

# Mixture weights over the three components.
probs = jnp.asarray([0.05, 0.1, 0.85])

# Constructing the mixture raises the TypeError.
mixture = tfd.Mixture(
    cat=tfd.Categorical(probs=probs),
    components=[dist1, dist2, dist3],
)
```
Version
I used Python 3.11 and TFP 0.20.1. I then reproduced the behavior using the development version, 0.25.0.dev20240628.
Full error message
```
TypeError Traceback (most recent call last)
Cell In[5], line 14
10 dist3 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.5), lambda x: tfd.Normal(2.0 + x, 0.2)])
12 probs = jnp.asarray([0.05, 0.1, 0.85])
---> 14 mixture = tfd.Mixture(
15 cat=tfd.Categorical(probs=probs),
16 components=[dist1, dist2, dist3],
17 )
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mixture.py:155, in _Mixture.__init__(self, cat, components, validate_args, allow_nan_stats, name)
153 static_batch_shape = cat.batch_shape
154 for di, d in enumerate(components):
--> 155 if not tensorshape_util.is_compatible_with(static_batch_shape,
156 d.batch_shape):
157 raise ValueError(
158 'components[{}] batch shape must be compatible with cat '
159 'shape and other component batch shapes ({} vs {})'.format(
160 di, static_batch_shape, d.batch_shape))
161 if not tensorshape_util.is_compatible_with(static_event_shape,
162 d.event_shape):
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/tensorshape_util.py:211, in is_compatible_with(x, other)
199 def is_compatible_with(x, other):
200 """Returns `True` iff `x` is compatible with `other`.
201
202 For more details, see `help(tf.TensorShape.is_compatible_with)`.
(...)
209 is_compatible: `bool` indicating of the shapes are compatible.
210 """
--> 211 return tf.TensorShape(x).is_compatible_with(other)
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:1437, in TensorShape.is_compatible_with(self, other)
1399 def is_compatible_with(self, other):
1400 """Returns True iff `self` is compatible with `other`.
1401
1402 Two possibly-partially-defined shapes are compatible if there
(...)
1435
1436 """
-> 1437 other = as_shape(other)
1438 if self.dims is not None and other.dims is not None:
1439 if self.rank != other.rank:
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:1624, in as_shape(shape)
1622 return shape
1623 else:
-> 1624 return TensorShape(shape)
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:905, in TensorShape.__init__(self, dims)
896 """Creates a new TensorShape with the given dimensions.
897
898 Args:
(...)
902 TypeError: If dims cannot be converted to a list of dimensions.
903 """
904 if isinstance(dims, (tuple, list)): # Most common case.
--> 905 self._dims = tuple(as_dimension(d).value for d in dims)
906 elif dims is None:
907 self._dims = None
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:905, in <genexpr>(.0)
896 """Creates a new TensorShape with the given dimensions.
897
898 Args:
(...)
902 TypeError: If dims cannot be converted to a list of dimensions.
903 """
904 if isinstance(dims, (tuple, list)): # Most common case.
--> 905 self._dims = tuple(as_dimension(d).value for d in dims)
906 elif dims is None:
907 self._dims = None
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:819, in as_dimension(value)
817 return value
818 else:
--> 819 return Dimension(value)
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:295, in Dimension.__init__(self, value)
293 self._value = int(value.__index__())
294 except AttributeError:
--> 295 raise TypeError(
296 "Dimension value must be integer or None or have "
297 "an __index__ method, got value '{0!r}' with type '{1!r}'".format(
298 value, type(value))) from None
299 if self._value < 0:
300 raise ValueError("Dimension %d must be >= 0" % self._value)
TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'
```
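The last frames show the mechanism: `TensorShape.__init__` treats a list as a list of dimensions, so a component's list of per-part shapes fails on its first entry, which is itself a shape rather than an integer. The pure-Python sketch below mimics that conversion. `as_dimension` and `as_shape` here are simplified, hypothetical stand-ins for the backend functions in the traceback (not TFP's code), and plain tuples stand in for `TensorShape`:

```python
def as_dimension(value):
    """Hypothetical stand-in for as_dimension: accept None or anything with __index__."""
    if value is None:
        return None
    try:
        return int(value.__index__())
    except AttributeError:
        raise TypeError(
            "Dimension value must be integer or None or have an __index__ method, "
            f"got value {value!r} with type {type(value)!r}") from None


def as_shape(dims):
    """Hypothetical stand-in for TensorShape(dims): every entry must be a dimension."""
    return tuple(as_dimension(d) for d in dims)


# cat.batch_shape is a flat, empty shape, which converts fine.
cat_batch_shape = as_shape([])

# A two-part JointDistributionSequential reports one shape per part,
# modelled here as a list of two empty shapes.
jds_batch_shape = [(), ()]

try:
    as_shape(jds_batch_shape)
except TypeError as err:
    # The nested shape () has no __index__, so it is rejected as a dimension.
    print(err)
```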