probability
Categorical distribution / JAX vmap error
I am trying to use a Categorical distribution inside a function transformed with JAX `vmap`, but the call raises a `ValueError` (full traceback below). I also tested `pmap` and `jit`, and both work fine. Thanks for the help.
Colab notebook for reproducibility (minimal code below):
Versions: jax==0.3.1, tensorflow-probability (tfp)==0.16.0
# Minimal reproduction: returning a TFP distribution from a jax.vmap-ed
# function raises ValueError (jax==0.3.1, tensorflow-probability==0.16.0).
import jax
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions


# Why this fails (per the traceback below):
# - After tracing the function, vmap's out_axes handling (`flatten_axes` in
#   jax/_src/api_util.py) rebuilds the *output* pytree via `tree_unflatten`,
#   using placeholder `object()` values as leaves.
# - The TFP JAX substrate registers distributions as pytrees. Its `unflatten`
#   calls `cls(**parameters)`, so `Categorical.__init__` is re-run with a bare
#   `object()` as `logits`.
# - Tensor conversion then rejects that placeholder.
# NOTE(review): presumably any vmap-ed function whose *output* contains a TFP
# distribution hits this path. Returning plain arrays (e.g. the logits,
# samples, or log-probs) from the vmap-ed function should avoid it — confirm.
@jax.vmap
def fun(logits):
    # With vmap's default in_axes=0, each call sees one slice of `logits`
    # taken along the leading axis.
    dist = tfd.Categorical(logits=logits)
    # Returning the Distribution object itself is what triggers the failure.
    return dist


# Input of shape (4, 8, 5): vmap maps over the leading axis (size 4), so each
# traced call receives an (8, 5) array of logits.
fun(jax.numpy.ones((4, 8, 5)))
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
[<ipython-input-6-cf4434029c53>](https://localhost:8080/#) in <module>()
7
----> 8 fun(jax.numpy.ones((4, 8, 5)))
23 frames
[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
164 try:
--> 165 return fun(*args, **kwargs)
166 except Exception as e:
[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in vmap_f(*args, **kwargs)
1555 lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
-> 1556 ).call_wrapped(*args_flat)
1557 return tree_unflatten(out_tree(), out_flat)
[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
180 try:
--> 181 ans = gen.send(ans)
182 except:
[/usr/local/lib/python3.7/dist-packages/jax/interpreters/batching.py](https://localhost:8080/#) in _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals)
386 outs = yield in_tracers, {}
--> 387 out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
388 out_vals = map(partial(from_elt, trace, axis_size), outs, out_dim_dests)
[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in <lambda>()
1554 flat_fun, axis_name, axis_size_, in_axes_flat,
-> 1555 lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
1556 ).call_wrapped(*args_flat)
[/usr/local/lib/python3.7/dist-packages/jax/_src/api_util.py](https://localhost:8080/#) in flatten_axes(name, treedef, axis_tree, kws, tupled_args)
284 proxy = object()
--> 285 dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
286 axes = []
[/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py](https://localhost:8080/#) in tree_unflatten(treedef, leaves)
70 """
---> 71 return treedef.unflatten(leaves)
72
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/distribution.py](https://localhost:8080/#) in unflatten(info, xs)
391 parameters = dict(list(zip(keys, xs)), **metadata)
--> 392 return cls(**parameters)
393 from jax import tree_util # pylint: disable=g-import-not-at-top
<decorator-gen-290> in __init__(self, logits, probs, dtype, validate_args, allow_nan_stats, name)
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/distribution.py](https://localhost:8080/#) in wrapped_init(***failed resolving arguments***)
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/categorical.py](https://localhost:8080/#) in __init__(self, logits, probs, dtype, validate_args, allow_nan_stats, name)
191 self._logits = tensor_util.convert_nonref_to_tensor(
--> 192 logits, dtype_hint=prob_logit_dtype, name='logits')
193 super(Categorical, self).__init__(
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/internal/tensor_util.py](https://localhost:8080/#) in convert_nonref_to_tensor(value, dtype, dtype_hint, as_shape_tensor, name)
119 return tf.convert_to_tensor(
--> 120 value, dtype=dtype, dtype_hint=dtype_hint, name=name)
121
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/_utils.py](https://localhost:8080/#) in wrap(***failed resolving arguments***)
61 del instance, wrapped
---> 62 return new_fn(*args, **kwargs)
63 return wrap(original_fn) # pylint: disable=no-value-for-parameter
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _convert_to_tensor(value, dtype, dtype_hint, name)
162 if ret is None:
--> 163 ret = conversion_func(value, dtype=dtype)
164 return ret
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _default_convert_to_tensor(value, dtype)
214 """Default tensor conversion function for array, bool, int, float, and complex."""
--> 215 inferred_dtype = _infer_dtype(value, np.float32)
216 # When a dtype is provided, we can go ahead and try converting to the dtype
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _infer_dtype(value, default_dtype)
192 ' with an unsupported type ({}) to a Tensor.').format(
--> 193 value, type(value)))
194
UnfilteredStackTrace: ValueError: Attempt to convert a value (<object object at 0x7fe9f8a3c7a0>) with an unsupported type (<class 'object'>) to a Tensor.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
[<ipython-input-6-cf4434029c53>](https://localhost:8080/#) in <module>()
6 return dist
7
----> 8 fun(jax.numpy.ones((4, 8, 5)))
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/distribution.py](https://localhost:8080/#) in unflatten(info, xs)
390 keys, metadata = info
391 parameters = dict(list(zip(keys, xs)), **metadata)
--> 392 return cls(**parameters)
393 from jax import tree_util # pylint: disable=g-import-not-at-top
394 tree_util.register_pytree_node(cls, flatten, unflatten)
<decorator-gen-290> in __init__(self, logits, probs, dtype, validate_args, allow_nan_stats, name)
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/distribution.py](https://localhost:8080/#) in wrapped_init(***failed resolving arguments***)
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/categorical.py](https://localhost:8080/#) in __init__(self, logits, probs, dtype, validate_args, allow_nan_stats, name)
190 probs, dtype_hint=prob_logit_dtype, name='probs')
191 self._logits = tensor_util.convert_nonref_to_tensor(
--> 192 logits, dtype_hint=prob_logit_dtype, name='logits')
193 super(Categorical, self).__init__(
194 dtype=dtype,
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/internal/tensor_util.py](https://localhost:8080/#) in convert_nonref_to_tensor(value, dtype, dtype_hint, as_shape_tensor, name)
118 value, dtype=dtype, dtype_hint=dtype_hint, name=name)
119 return tf.convert_to_tensor(
--> 120 value, dtype=dtype, dtype_hint=dtype_hint, name=name)
121
122
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/_utils.py](https://localhost:8080/#) in wrap(***failed resolving arguments***)
60 def wrap(wrapped, instance, args, kwargs):
61 del instance, wrapped
---> 62 return new_fn(*args, **kwargs)
63 return wrap(original_fn) # pylint: disable=no-value-for-parameter
64
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _convert_to_tensor(value, dtype, dtype_hint, name)
161
162 if ret is None:
--> 163 ret = conversion_func(value, dtype=dtype)
164 return ret
165
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _default_convert_to_tensor(value, dtype)
213 def _default_convert_to_tensor(value, dtype=None):
214 """Default tensor conversion function for array, bool, int, float, and complex."""
--> 215 inferred_dtype = _infer_dtype(value, np.float32)
216 # When a dtype is provided, we can go ahead and try converting to the dtype
217 # and force overflow/underflow if an int64 is converted to an int32.
[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _infer_dtype(value, default_dtype)
191 raise ValueError(('Attempt to convert a value ({})'
192 ' with an unsupported type ({}) to a Tensor.').format(
--> 193 value, type(value)))
194
195
ValueError: Attempt to convert a value (<object object at 0x7fe9f8a3c7a0>) with an unsupported type (<class 'object'>) to a Tensor.