jax
jax.random.randint silently ignores dtype=int64 instead of emitting a warning
Description
import jax.numpy as jnp
import jax.random as jr
k = jr.PRNGKey(0)
a = jr.randint(k, (10,), minval=10, maxval=20, dtype=jnp.int64) # no warning is printed!
assert a.dtype == jnp.int32  # jax_enable_x64 is off by default, so the int64 request is truncated to int32
a = a.astype(jnp.int64) # warning is printed
What jax/jaxlib version are you using?
3.17
Which accelerator(s) are you using?
CPU
Additional System Info
macos
Thanks for the report! It looks like we forgot to call _check_user_dtype_supported in randint.
It looks like this has been fixed.
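The following is a minimal, self-contained sketch of the check the missing call performs. It assumes that JAX narrows 64-bit dtypes to 32-bit ones when jax_enable_x64 is off, and that _check_user_dtype_supported warns whenever the requested dtype differs from the one that will actually be used. All names below (ENABLE_X64, canonicalize_dtype, check_user_dtype_supported, randint_sketch) are hypothetical stand-ins written for illustration. They are not JAX's private API, whose exact signature and warning text may differ. The sketch shows why skipping the check makes an int64 request quietly turn into int32.

```python
import warnings

# Hypothetical stand-in for JAX's jax_enable_x64 flag (off by default in JAX).
ENABLE_X64 = False

# Assumed narrowing rules applied when 64-bit mode is disabled.
_NARROW = {"int64": "int32", "uint64": "uint32",
           "float64": "float32", "complex128": "complex64"}


def canonicalize_dtype(dtype: str) -> str:
    """Return the dtype that will actually be used under the current config."""
    return dtype if ENABLE_X64 else _NARROW.get(dtype, dtype)


def check_user_dtype_supported(dtype: str, fun_name: str) -> None:
    """Warn if an explicitly requested dtype is going to be truncated."""
    actual = canonicalize_dtype(dtype)
    if actual != dtype:
        warnings.warn(
            f"Explicitly requested dtype {dtype} requested in {fun_name} is not "
            f"available, and will be truncated to dtype {actual}.",
            UserWarning,
            stacklevel=2,
        )


def randint_sketch(dtype: str = "int32") -> str:
    """Toy 'randint' that only resolves its output dtype.

    Validating the user's dtype before canonicalizing it is the step that was
    missing; without it, "int64" silently becomes "int32".
    """
    check_user_dtype_supported(dtype, "randint")
    return canonicalize_dtype(dtype)


# Record warnings so the truncation is visible instead of silent.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = randint_sketch("int64")

print(out)           # int32
print(len(caught))   # 1
```

With the check in place, the request still resolves to int32 under the default config, but the user is told about the truncation, matching the warning that a.astype(jnp.int64) already produced in the report.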