Can't slice along batch for MixtureSameFamily produced by a Keras layer
Not sure if this is a Keras bug or a tfp bug.
I'm trying to make some dense layers that output the parameters of a Gaussian mixture (a mixture density network). I want to run a batch of data through the network (for speed), get out a batch of distributions, and then slice so I can work with only some elements of the batch at a time. If I were doing this with just tfp, I would call:
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

num_mixture_components = 12
batch_size = 4
# random parameters, just to build a batched mixture distribution
probs = np.random.rand(batch_size, num_mixture_components)
loc = np.random.rand(batch_size, num_mixture_components)
scale = np.random.rand(batch_size, num_mixture_components)

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=probs),
    components_distribution=tfd.Normal(loc=loc, scale=scale))

print(gm)
# slice along batch
print(gm[1])
This works as expected, giving:
tfp.distributions.MixtureSameFamily("MixtureSameFamily", batch_shape=[4], event_shape=[], dtype=float64)
tfp.distributions.MixtureSameFamily("MixtureSameFamily", batch_shape=[], event_shape=[], dtype=float64)
However, when I try the same thing with a mixture density network built in Keras, I get an error.
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl
import tensorflow_probability as tfp

num_mixture_components = 12

l = tfkl.Input(shape=(100,))
# A fully connected network that outputs the parameters of a Gaussian mixture
mu = tfkl.Dense(units=num_mixture_components, activation=None)(l)
sigma = tfkl.Dense(units=num_mixture_components, activation='softplus')(l)
alpha = tfkl.Dense(units=num_mixture_components, activation='softmax')(l)
stacked = tfkl.Concatenate()([mu, sigma, alpha])
mixture = tfp.layers.MixtureNormal(num_mixture_components,
                                   event_shape=[], name="test")(stacked)

model = tf.keras.Model(inputs=l, outputs=mixture)
out = model(np.random.rand(4, 100))
gm = out.tensor_distribution
print(gm)
# slice along batch
print(gm[1])
This gives me a cryptic error:
tfp.distributions._MixtureSameFamily("MixtureSameFamily", batch_shape=[4], event_shape=[], dtype=float32)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In [72], line 27
25 print(gm)
26 # slice along batch
---> 27 print(gm[1])
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py:852, in Distribution.__getitem__(self, slices)
825 def __getitem__(self, slices):
826 """Slices the batch axes of this distribution, returning a new instance.
827
828 ```python
(...)
850 dist: A new `tfd.Distribution` instance with sliced parameters.
851 """
--> 852 return slicing.batch_slice(self, {}, slices)
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:220, in batch_slice(batch_object, params_overrides, slices, bijector_x_event_ndims)
217 slice_overrides_seq = slice_overrides_seq + [(slices, params_overrides)]
218 # Re-doing the full sequence of slice+copy override work here enables
219 # gradients all the way back to the original batch_objectribution's arguments.
--> 220 batch_object = _apply_slice_sequence(
221 orig_batch_object,
222 slice_overrides_seq,
223 bijector_x_event_ndims=bijector_x_event_ndims)
224 setattr(batch_object,
225 PROVENANCE_ATTR,
226 batch_object._no_dependency((orig_batch_object, slice_overrides_seq))) # pylint: disable=protected-access
227 return batch_object
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:179, in _apply_slice_sequence(batch_object, slice_overrides_seq, bijector_x_event_ndims)
177 """Applies a sequence of slice or copy-with-overrides operations to `batch_object`."""
178 for slices, overrides in slice_overrides_seq:
--> 179 batch_object = _apply_single_step(
180 batch_object,
181 slices,
182 overrides,
183 bijector_x_event_ndims=bijector_x_event_ndims)
184 return batch_object
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:168, in _apply_single_step(batch_object, slices, params_overrides, bijector_x_event_ndims)
166 override_dict = {}
167 else:
--> 168 override_dict = _slice_params_to_dict(
169 batch_object, slices, bijector_x_event_ndims=bijector_x_event_ndims)
170 override_dict.update(params_overrides)
171 parameters = dict(batch_object.parameters, **override_dict)
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:153, in _slice_params_to_dict(batch_object, slices, bijector_x_event_ndims)
150 else:
151 batch_shape = batch_object.experimental_batch_shape_tensor(
152 x_event_ndims=bijector_x_event_ndims)
--> 153 return batch_shape_lib.map_fn_over_parameters_with_event_ndims(
154 batch_object,
155 functools.partial(_slice_single_param,
156 slices=slices,
157 batch_shape=batch_shape),
158 bijector_x_event_ndims=bijector_x_event_ndims)
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/batch_shape_lib.py:367, in map_fn_over_parameters_with_event_ndims(batch_object, fn, bijector_x_event_ndims, require_static, **parameter_kwargs)
361 elif (properties.is_tensor
362 and not tf.is_tensor(param)
363 and not tf.nest.is_nested(param_event_ndims)):
364 # As a last resort, try an explicit conversion.
365 param = tensor_util.convert_nonref_to_tensor(param, name=param_name)
--> 367 results[param_name] = nest.map_structure_up_to(
368 param, fn, param, param_event_ndims)
369 return results
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/nest.py:1435, in map_structure_up_to(shallow_tree, func, *inputs, **kwargs)
1361 @tf_export("__internal__.nest.map_structure_up_to", v1=[])
1362 def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
1363 """Applies a function or op to a number of partially flattened inputs.
1364
1365 The `inputs` are flattened up to `shallow_tree` before being mapped.
(...)
1433 `shallow_tree`.
1434 """
-> 1435 return map_structure_with_tuple_paths_up_to(
1436 shallow_tree,
1437 lambda _, *values: func(*values), # Discards the path arg.
1438 *inputs,
1439 **kwargs)
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/nest.py:1535, in map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs)
1526 flat_value_gen = (
1527 flatten_up_to( # pylint: disable=g-complex-comprehension
1528 shallow_tree,
1529 input_tree,
1530 check_types,
1531 expand_composites=expand_composites) for input_tree in inputs)
1532 flat_path_gen = (
1533 path
1534 for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn))
-> 1535 results = [
1536 func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen)
1537 ]
1538 return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
1539 expand_composites=expand_composites)
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/nest.py:1536, in <listcomp>(.0)
1526 flat_value_gen = (
1527 flatten_up_to( # pylint: disable=g-complex-comprehension
1528 shallow_tree,
1529 input_tree,
1530 check_types,
1531 expand_composites=expand_composites) for input_tree in inputs)
1532 flat_path_gen = (
1533 path
1534 for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn))
1535 results = [
-> 1536 func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen)
1537 ]
1538 return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
1539 expand_composites=expand_composites)
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/nest.py:1437, in map_structure_up_to.<locals>.<lambda>(_, *values)
1361 @tf_export("__internal__.nest.map_structure_up_to", v1=[])
1362 def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
1363 """Applies a function or op to a number of partially flattened inputs.
1364
1365 The `inputs` are flattened up to `shallow_tree` before being mapped.
(...)
1433 `shallow_tree`.
1434 """
1435 return map_structure_with_tuple_paths_up_to(
1436 shallow_tree,
-> 1437 lambda _, *values: func(*values), # Discards the path arg.
1438 *inputs,
1439 **kwargs)
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:101, in _slice_single_param(param, param_event_ndims, slices, batch_shape)
85 """Slices into the batch shape of a single parameter.
86
87 Args:
(...)
98 `slices`.
99 """
100 # Broadcast the parmameter to have full batch rank.
--> 101 param = batch_shape_lib.broadcast_parameter_with_batch_shape(
102 param, param_event_ndims, ps.ones_like(batch_shape))
103 param_batch_shape = batch_shape_lib.get_batch_shape_tensor_part(
104 param, param_event_ndims)
105 # At this point the param should have full batch rank, *unless* it's an
106 # atomic object like `tfb.Identity()` incapable of having any batch rank.
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/batch_shape_lib.py:274, in broadcast_parameter_with_batch_shape(param, param_event_ndims, batch_shape)
270 base_shape = ps.concat([batch_shape,
271 ps.ones([param_event_ndims], dtype=np.int32)],
272 axis=0)
273 if hasattr(param, '_broadcast_parameters_with_batch_shape'):
--> 274 return param._broadcast_parameters_with_batch_shape(base_shape) # pylint: disable=protected-access
275 elif hasattr(param, 'matmul'):
276 # TODO(davmre): support broadcasting LinearOperator parameters.
277 return param
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py:952, in Distribution._broadcast_parameters_with_batch_shape(self, batch_shape)
926 def _broadcast_parameters_with_batch_shape(self, batch_shape):
927 """Broadcasts each parameter's batch shape with the given `batch_shape`.
928
929 This is semantically equivalent to wrapping with the `BatchBroadcast`
(...)
950 the given `batch_shape`.
951 """
--> 952 return self.copy(
953 **batch_shape_lib.broadcast_parameters_with_batch_shape(
954 self, batch_shape))
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py:915, in Distribution.copy(self, **override_parameters_kwargs)
897 """Creates a deep copy of the distribution.
898
899 Note: the copy distribution may continue to depend on the original
(...)
909 `dict(self.parameters, **override_parameters_kwargs)`.
910 """
911 try:
912 # We want track provenance from origin variables, so we use batch_slice
913 # if this distribution supports slicing. See the comment on
914 # PROVENANCE_ATTR in batch_slicing.py
--> 915 return slicing.batch_slice(self, override_parameters_kwargs, Ellipsis)
916 except NotImplementedError:
917 pass
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:220, in batch_slice(batch_object, params_overrides, slices, bijector_x_event_ndims)
217 slice_overrides_seq = slice_overrides_seq + [(slices, params_overrides)]
218 # Re-doing the full sequence of slice+copy override work here enables
219 # gradients all the way back to the original batch_objectribution's arguments.
--> 220 batch_object = _apply_slice_sequence(
221 orig_batch_object,
222 slice_overrides_seq,
223 bijector_x_event_ndims=bijector_x_event_ndims)
224 setattr(batch_object,
225 PROVENANCE_ATTR,
226 batch_object._no_dependency((orig_batch_object, slice_overrides_seq))) # pylint: disable=protected-access
227 return batch_object
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:179, in _apply_slice_sequence(batch_object, slice_overrides_seq, bijector_x_event_ndims)
177 """Applies a sequence of slice or copy-with-overrides operations to `batch_object`."""
178 for slices, overrides in slice_overrides_seq:
--> 179 batch_object = _apply_single_step(
180 batch_object,
181 slices,
182 overrides,
183 bijector_x_event_ndims=bijector_x_event_ndims)
184 return batch_object
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/slicing.py:172, in _apply_single_step(batch_object, slices, params_overrides, bijector_x_event_ndims)
170 override_dict.update(params_overrides)
171 parameters = dict(batch_object.parameters, **override_dict)
--> 172 return type(batch_object)(**parameters)
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/decorator.py:231, in decorate.<locals>.fun(*args, **kw)
229 def fun(*args, **kw):
230 if not kwsyntax:
--> 231 args, kw = fix(args, kw, sig)
232 return caller(func, *(extras + args), **kw)
File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/decorator.py:203, in fix(args, kwargs, sig)
199 def fix(args, kwargs, sig):
200 """
201 Fix args and kwargs to be consistent with the signature
202 """
--> 203 ba = sig.bind(*args, **kwargs)
204 ba.apply_defaults() # needed for test_dan_schult
205 return ba.args, ba.kwargs
File ~/mambaforge/envs/phenotypes/lib/python3.10/inspect.py:3179, in Signature.bind(self, *args, **kwargs)
3174 def bind(self, /, *args, **kwargs):
3175 """Get a BoundArguments object, that maps the passed `args`
3176 and `kwargs` to the function's signature. Raises `TypeError`
3177 if the passed arguments can not be bound.
3178 """
-> 3179 return self._bind(args, kwargs)
File ~/mambaforge/envs/phenotypes/lib/python3.10/inspect.py:3168, in Signature._bind(self, args, kwargs, partial)
3166 arguments[kwargs_param.name] = kwargs
3167 else:
-> 3168 raise TypeError(
3169 'got an unexpected keyword argument {arg!r}'.format(
3170 arg=next(iter(kwargs))))
3172 return self._bound_arguments_cls(self, arguments)
TypeError: got an unexpected keyword argument 'reinterpreted_batch_ndims'
I'm using tensorflow 2.11.0 and tensorflow-probability 0.19.0.
Any ideas why this is happening and how to fix it?
I am also running into this same issue. Did you ever figure out what the problem was, @henrypinkard?
No, but replacing the MixtureNormal layer with a DistributionLambda mostly fixed it:
tfpl = tfp.layers
mixture = tfpl.DistributionLambda(
    make_distribution_fn=lambda params: tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=params[0]),
        components_distribution=tfd.Normal(loc=params[1], scale=params[2])),
    convert_to_tensor_fn=tfd.Distribution.sample)([alpha, mu, sigma])
It's unclear to me what the point of having special classes for distribution layers is when you can just do this instead...
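For completeness, here is a minimal end-to-end sketch of that workaround, assembled from the snippets above (the layer sizes are just placeholders, and the slicing at the end is what I would expect from the pure-tfp example at the top rather than something I've verified further):
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

num_mixture_components = 12

l = tfkl.Input(shape=(100,))
# Three Dense heads producing the mixture parameters
mu = tfkl.Dense(units=num_mixture_components, activation=None)(l)
sigma = tfkl.Dense(units=num_mixture_components, activation='softplus')(l)
alpha = tfkl.Dense(units=num_mixture_components, activation='softmax')(l)

# Build the mixture with DistributionLambda instead of MixtureNormal
mixture = tfpl.DistributionLambda(
    make_distribution_fn=lambda params: tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=params[0]),
        components_distribution=tfd.Normal(loc=params[1], scale=params[2])),
    convert_to_tensor_fn=tfd.Distribution.sample)([alpha, mu, sigma])

model = tf.keras.Model(inputs=l, outputs=mixture)

out = model(np.random.rand(4, 100))
gm = out.tensor_distribution  # the underlying tfd.MixtureSameFamily, batch_shape=[4]
print(gm)
# slicing along the batch should now behave like the plain tfd.MixtureSameFamily above
print(gm[1])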