probability icon indicating copy to clipboard operation
probability copied to clipboard

Categorical distributions / JAX vmap Error

Open XMaster96 opened this issue 2 years ago • 0 comments

I am trying to use a Categorical distributions inside of a function with JAX vmap. I also tested pmap and jit and both are working fine. Thanks for the help.

colab notebook for reproducibility

jax==0.3.1 tfp==0.16.0

import jax
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

@jax.vmap
def fun(logits):
    dist = tfd.Categorical(logits=logits)

    return dist
    
fun(jax.numpy.ones((4, 8, 5)))
---------------------------------------------------------------------------

UnfilteredStackTrace                      Traceback (most recent call last)

[<ipython-input-6-cf4434029c53>](https://localhost:8080/#) in <module>()
      7 
----> 8 fun(jax.numpy.ones((4, 8, 5)))

23 frames

[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
    164     try:
--> 165       return fun(*args, **kwargs)
    166     except Exception as e:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in vmap_f(*args, **kwargs)
   1555         lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
-> 1556     ).call_wrapped(*args_flat)
   1557     return tree_unflatten(out_tree(), out_flat)

[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
    180       try:
--> 181         ans = gen.send(ans)
    182       except:

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/batching.py](https://localhost:8080/#) in _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals)
    386   outs = yield in_tracers, {}
--> 387   out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
    388   out_vals = map(partial(from_elt, trace, axis_size), outs, out_dim_dests)

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in <lambda>()
   1554         flat_fun, axis_name, axis_size_, in_axes_flat,
-> 1555         lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
   1556     ).call_wrapped(*args_flat)

[/usr/local/lib/python3.7/dist-packages/jax/_src/api_util.py](https://localhost:8080/#) in flatten_axes(name, treedef, axis_tree, kws, tupled_args)
    284   proxy = object()
--> 285   dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
    286   axes = []

[/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py](https://localhost:8080/#) in tree_unflatten(treedef, leaves)
     70   """
---> 71   return treedef.unflatten(leaves)
     72 

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/distribution.py](https://localhost:8080/#) in unflatten(info, xs)
    391       parameters = dict(list(zip(keys, xs)), **metadata)
--> 392       return cls(**parameters)
    393     from jax import tree_util  # pylint: disable=g-import-not-at-top

<decorator-gen-290> in __init__(self, logits, probs, dtype, validate_args, allow_nan_stats, name)

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/distribution.py](https://localhost:8080/#) in wrapped_init(***failed resolving arguments***)
    341       self_._parameters = None
--> 342       default_init(self_, *args, **kwargs)
    343       # Note: if we ever want to override things set in `self` by subclass

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/categorical.py](https://localhost:8080/#) in __init__(self, logits, probs, dtype, validate_args, allow_nan_stats, name)
    191       self._logits = tensor_util.convert_nonref_to_tensor(
--> 192           logits, dtype_hint=prob_logit_dtype, name='logits')
    193       super(Categorical, self).__init__(

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/internal/tensor_util.py](https://localhost:8080/#) in convert_nonref_to_tensor(value, dtype, dtype_hint, as_shape_tensor, name)
    119   return tf.convert_to_tensor(
--> 120       value, dtype=dtype, dtype_hint=dtype_hint, name=name)
    121 

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/_utils.py](https://localhost:8080/#) in wrap(***failed resolving arguments***)
     61     del instance, wrapped
---> 62     return new_fn(*args, **kwargs)
     63   return wrap(original_fn)  # pylint: disable=no-value-for-parameter

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _convert_to_tensor(value, dtype, dtype_hint, name)
    162   if ret is None:
--> 163     ret = conversion_func(value, dtype=dtype)
    164   return ret

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _default_convert_to_tensor(value, dtype)
    214   """Default tensor conversion function for array, bool, int, float, and complex."""
--> 215   inferred_dtype = _infer_dtype(value, np.float32)
    216   # When a dtype is provided, we can go ahead and try converting to the dtype

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _infer_dtype(value, default_dtype)
    192                     ' with an unsupported type ({}) to a Tensor.').format(
--> 193                         value, type(value)))
    194 

UnfilteredStackTrace: ValueError: Attempt to convert a value (<object object at 0x7fe9f8a3c7a0>) with an unsupported type (<class 'object'>) to a Tensor.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------


The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)

[<ipython-input-6-cf4434029c53>](https://localhost:8080/#) in <module>()
      6     return dist
      7 
----> 8 fun(jax.numpy.ones((4, 8, 5)))

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/distribution.py](https://localhost:8080/#) in unflatten(info, xs)
    390       keys, metadata = info
    391       parameters = dict(list(zip(keys, xs)), **metadata)
--> 392       return cls(**parameters)
    393     from jax import tree_util  # pylint: disable=g-import-not-at-top
    394     tree_util.register_pytree_node(cls, flatten, unflatten)

<decorator-gen-290> in __init__(self, logits, probs, dtype, validate_args, allow_nan_stats, name)

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/distribution.py](https://localhost:8080/#) in wrapped_init(***failed resolving arguments***)
    340       # called, here is the place to do it.
    341       self_._parameters = None
--> 342       default_init(self_, *args, **kwargs)
    343       # Note: if we ever want to override things set in `self` by subclass
    344       # `__init__`, here is the place to do it.

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/distributions/categorical.py](https://localhost:8080/#) in __init__(self, logits, probs, dtype, validate_args, allow_nan_stats, name)
    190           probs, dtype_hint=prob_logit_dtype, name='probs')
    191       self._logits = tensor_util.convert_nonref_to_tensor(
--> 192           logits, dtype_hint=prob_logit_dtype, name='logits')
    193       super(Categorical, self).__init__(
    194           dtype=dtype,

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/substrates/jax/internal/tensor_util.py](https://localhost:8080/#) in convert_nonref_to_tensor(value, dtype, dtype_hint, as_shape_tensor, name)
    118         value, dtype=dtype, dtype_hint=dtype_hint, name=name)
    119   return tf.convert_to_tensor(
--> 120       value, dtype=dtype, dtype_hint=dtype_hint, name=name)
    121 
    122 

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/_utils.py](https://localhost:8080/#) in wrap(***failed resolving arguments***)
     60   def wrap(wrapped, instance, args, kwargs):
     61     del instance, wrapped
---> 62     return new_fn(*args, **kwargs)
     63   return wrap(original_fn)  # pylint: disable=no-value-for-parameter
     64 

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _convert_to_tensor(value, dtype, dtype_hint, name)
    161 
    162   if ret is None:
--> 163     ret = conversion_func(value, dtype=dtype)
    164   return ret
    165 

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _default_convert_to_tensor(value, dtype)
    213 def _default_convert_to_tensor(value, dtype=None):
    214   """Default tensor conversion function for array, bool, int, float, and complex."""
--> 215   inferred_dtype = _infer_dtype(value, np.float32)
    216   # When a dtype is provided, we can go ahead and try converting to the dtype
    217   # and force overflow/underflow if an int64 is converted to an int32.

[/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/internal/backend/jax/ops.py](https://localhost:8080/#) in _infer_dtype(value, default_dtype)
    191   raise ValueError(('Attempt to convert a value ({})'
    192                     ' with an unsupported type ({}) to a Tensor.').format(
--> 193                         value, type(value)))
    194 
    195 

ValueError: Attempt to convert a value (<object object at 0x7fe9f8a3c7a0>) with an unsupported type (<class 'object'>) to a Tensor.

XMaster96 avatar Mar 03 '22 11:03 XMaster96