Struggles with Gaussian Mixture Models in the JAX substrate of tensorflow_probability
I am getting the strangest bugs when trying to write a Gaussian mixture model class in the JAX substrate of tfp. Has anyone seen this before, or does anyone know the correct course of action? Basically, I either get an annoying warning message and a sample of the right shape, or no warning but a sample of the wrong shape.
When I try this code
import jax.numpy as jnp
import jax.random as jr
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

class GMM(tfd.MixtureSameFamily):
    def __init__(self, locs, log_scales, log_weights):
        self.locs = locs
        self.log_scales = log_scales
        self.log_weights = log_weights
        mixture_dist = tfd.Categorical(logits=log_weights)
        component_dist = tfd.Normal(loc=locs, scale=jnp.exp(log_scales))
        super().__init__(mixture_dist, component_dist)

gmm = GMM(jnp.array([0., 1.]), jnp.array([0., 1.]), jnp.array([0., 0.]))
gmm.sample(sample_shape=(), seed=jr.PRNGKey(0))
I get this warning (printed twice)
WARNING:root:
Distribution subclass GMM inherits `_parameter_properties` from its parent (MixtureSameFamily) while also redefining `__init__`. The inherited annotations cover the following parameters: dict_keys(['mixture_distribution', 'components_distribution']). It is likely that these do not match the subclass parameters. This may lead to errors when computing batch shapes, slicing into batch dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor (e.g., when it is passed or returned from a `tf.function`), and possibly other cases. The recommended pattern for distribution subclasses is to define a new `_parameter_properties` method with the subclass parameters, and to store the corresponding parameter values as `self._parameters` in `__init__`, after calling the superclass constructor:

class MySubclass(tfd.SomeDistribution):

  def __init__(self, param_a, param_b):
    parameters = dict(locals())
    # ... do subclass initialization ...
    super(MySubclass, self).__init__(**base_class_params)
    # Ensure that the subclass (not base class) parameters are stored.
    self._parameters = parameters

  def _parameter_properties(self, dtype, num_classes=None):
    return dict(
        # Annotations may optionally specify properties, such as `event_ndims`,
        # `default_constraining_bijector_fn`, `specifies_shape`, etc.; see
        # the `ParameterProperties` documentation for details.
        param_a=tfp.util.ParameterProperties(),
        param_b=tfp.util.ParameterProperties())
followed by the sample Array(1.0883901, dtype=float32), so at least only one scalar gets sampled.
When I then try to follow the recommendation in the warning with this code
class GMM(tfd.MixtureSameFamily):
    def __init__(self, locs, log_scales, log_weights):
        parameters = dict(locals())
        self.locs = locs
        self.log_scales = log_scales
        self.log_weights = log_weights
        mixture_dist = tfd.Categorical(logits=log_weights)
        component_dist = tfd.Normal(loc=locs, scale=jnp.exp(log_scales))
        super().__init__(mixture_dist, component_dist)
        self._parameters = parameters

    def _parameter_properties(self, dtype=jnp.float32, num_classes=None):
        return dict(
            locs=tfp.util.ParameterProperties(),
            log_scales=tfp.util.ParameterProperties(),
            log_weights=tfp.util.ParameterProperties())

gmm = GMM(jnp.array([0., 1.]), jnp.array([0., 1.]), jnp.array([0., 0.]))
gmm.sample(sample_shape=(), seed=jr.PRNGKey(0))
The warning disappears, but now I get a length-2 sample, one value per mixture component, which is not what I want! Array([ 1.85066 , -2.4407113], dtype=float32)
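A quick way to see what is going wrong, even before sampling, is to inspect the inferred shapes (these are standard tfd.Distribution properties; the commented values are what I would expect here, not verified output):

print(gmm.batch_shape)  # expected [2]: the component axis has leaked into the batch shape
print(gmm.event_shape)  # expected []: each event is a scalar

With a batch shape of [2], sample(sample_shape=()) draws one value per batch entry, which is exactly the length-2 array above.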
It seems related to these issues:
- https://github.com/tensorflow/agents/issues/658
- https://github.com/tensorflow/probability/issues/1458
You might need tfp.util.ParameterProperties(event_ndims=1) for all of your parameters. It's awkwardly named, but basically indicates how many final dimensions of each parameter get consumed to produce a single event.
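If that is the fix, the second class above would become something like this (an untested sketch; the only change is the event_ndims=1 annotation):

class GMM(tfd.MixtureSameFamily):
    def __init__(self, locs, log_scales, log_weights):
        parameters = dict(locals())
        self.locs = locs
        self.log_scales = log_scales
        self.log_weights = log_weights
        mixture_dist = tfd.Categorical(logits=log_weights)
        component_dist = tfd.Normal(loc=locs, scale=jnp.exp(log_scales))
        super().__init__(mixture_dist, component_dist)
        self._parameters = parameters

    def _parameter_properties(self, dtype=jnp.float32, num_classes=None):
        # event_ndims=1: the trailing (component) axis of each parameter is
        # consumed to produce a single scalar event, so it should not be
        # treated as a batch dimension.
        return dict(
            locs=tfp.util.ParameterProperties(event_ndims=1),
            log_scales=tfp.util.ParameterProperties(event_ndims=1),
            log_weights=tfp.util.ParameterProperties(event_ndims=1))

With that annotation the inferred batch shape should be scalar again, so gmm.sample(sample_shape=(), seed=jr.PRNGKey(0)) should return a single value.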
Hey @brianwa84, could you please assign me the issue?