
MultivariateNormal* constructor crashes with Numpy 2.0

Status: Open. slinderman opened this issue 1 year ago (2 comments).

I updated to NumPy 2.0 and found that the MultivariateNormalDiag and MultivariateNormalFullCovariance constructors crash because np.issctype has been removed. Is NumPy 2.0 supported, or will it be soon?

Here is a simple repro:

import numpy as np
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

print(f"numpy version: {np.__version__}")
print(f"jax version: {jax.__version__}")
print(f"tfp version: {tfp.__version__}")

# works fine
nml = tfd.Normal(jnp.zeros(3), jnp.ones(3))

# fails with numpy 2.0
mvn = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))

# also fails with the same error on numpy 2.0
# mvn = tfd.MultivariateNormalFullCovariance(jnp.zeros(3), jnp.eye(3))

On my machine with Python 3.10, it produces the following output:

numpy version: 2.0.0
jax version: 0.4.30
tfp version: 0.24.0
Traceback (most recent call last):
  File "/Users/scott/Projects/dynamax/tfp_debug_20240618.py", line 15, in <module>
    mvn = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py", line 209, in __init__
    super(MultivariateNormalDiag, self).__init__(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_linear_operator.py", line 205, in __init__
    super(MultivariateNormalLinearOperator, self).__init__(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 244, in __init__
    dtype = self.bijector.forward_dtype(self.distribution.dtype)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1705, in forward_dtype
    input_dtype = nest.map_structure_up_to(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py", line 324, in map_structure_up_to
    return map_structure_with_tuple_paths_up_to(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py", line 353, in map_structure_with_tuple_paths_up_to
    return dm_tree.map_structure_with_path_up_to(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tree/__init__.py", line 778, in map_structure_with_path_up_to
    results.append(func(*path_and_values))
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py", line 326, in <lambda>
    lambda _, *args: func(*args),  # Discards path.
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1707, in <lambda>
    lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/dtype_util.py", line 247, in convert_to_dtype
    elif np.issctype(tensor_or_dtype):
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/numpy/__init__.py", line 397, in __getattr__
    raise AttributeError(
AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.. Did you mean: 'isdtype'?

slinderman, Jun 18 '24 18:06
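The last line of the traceback names the replacement NumPy suggests: issubclass(rep, np.generic). As a stopgap for code that called np.issctype directly, a minimal sketch of a NumPy 2-compatible helper built on that suggestion might look like the following. The name is_numpy_scalar_type is hypothetical and this is not TFP's actual fix; it only recognizes NumPy scalar classes, so it may not match the removed np.issctype for every input (dtype instances, for example, return False here).

```python
import numpy as np


def is_numpy_scalar_type(rep) -> bool:
    """Hypothetical stand-in for np.issctype, which NumPy 2.0 removed.

    Implements the replacement suggested by NumPy's AttributeError:
    issubclass(rep, np.generic). The isinstance guard avoids a TypeError
    from issubclass when rep is not a class (e.g. a dtype instance or a value).
    """
    return isinstance(rep, type) and issubclass(rep, np.generic)


print(is_numpy_scalar_type(np.float32))           # True
print(is_numpy_scalar_type(np.dtype("float32")))  # False: a dtype instance, not a class
print(is_numpy_scalar_type(float))                # False: Python float is not an np.generic subclass
```

The isinstance check comes first so that plain values and dtype objects short-circuit to False instead of raising.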

I got the same error when using tfd.MultivariateNormalTriL().

giladturok, Jul 19 '24 21:07

Is there any progress on this issue? It is preventing us from unpinning NumPy.

thomaspinder, Sep 12 '24 07:09

I checked again and this issue seems to be resolved in TFP version 0.25.0.

slinderman, May 10 '25 11:05
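If upgrading is the fix, a project can fail fast on the combination reported in this thread instead of crashing inside a distribution constructor. A minimal sketch follows, assuming the 0.25.0 threshold from the comment above is correct (it is one user's report, not a confirmed changelog entry). The names MIN_TFP_FOR_NUMPY2, release_tuple, and tfp_numpy_compatible are hypothetical, and the parser only looks at the leading numeric release segment of each version string.

```python
import re

# Hypothetical threshold, based on the report in this thread that the
# np.issctype crash no longer occurs in TFP 0.25.0.
MIN_TFP_FOR_NUMPY2 = (0, 25, 0)


def release_tuple(version: str) -> tuple[int, ...]:
    """Parse the leading numeric release segment: '2.0.0rc1' -> (2, 0, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    # Pad to three components so '0.25' compares equal to '0.25.0'.
    return parts + (0,) * (3 - len(parts))


def tfp_numpy_compatible(tfp_version: str, numpy_version: str) -> bool:
    """Return False for the combination reported to crash: NumPy >= 2 with TFP < 0.25."""
    if release_tuple(numpy_version)[0] < 2:
        return True  # NumPy 1.x still provides np.issctype
    return release_tuple(tfp_version) >= MIN_TFP_FOR_NUMPY2
```

In a project, this could be called as tfp_numpy_compatible(tfp.__version__, np.__version__) right after the imports in the repro above, raising an ImportError with an upgrade hint when it returns False. Pinning numpy<2 remains the alternative for projects that cannot move to a newer TFP.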