jax.random.randint silently ignores dtype of int64 instead of generating warning

Open dlwh opened this issue 3 years ago • 1 comments

Description

import jax.numpy as jnp
import jax.random as jr
k = jr.PRNGKey(0)
a = jr.randint(k, (10,), minval=10, maxval=20, dtype=jnp.int64)  # no warning is printed!
assert a.dtype == jnp.int32  # int64 was silently narrowed (x64 is disabled by default)
a = a.astype(jnp.int64)  # warning is printed

What jax/jaxlib version are you using?

0.3.17

Which accelerator(s) are you using?

CPU

Additional System Info

macOS

dlwh avatar Sep 14 '22 19:09 dlwh

Thanks for the report - it looks like we forgot to call _check_user_dtype_supported.

jakevdp avatar Sep 14 '22 19:09 jakevdp
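
For readers unfamiliar with the kind of check being discussed, here is a minimal sketch of the idea, not JAX's actual internal implementation. The helpers `warn_if_truncated` and `randint_checked` are hypothetical names invented for illustration. The sketch relies only on the public `jax.dtypes.canonicalize_dtype`, which maps 64-bit dtypes to their 32-bit counterparts unless `jax_enable_x64` is set.

```python
import warnings

import numpy as np
import jax
import jax.random as jr


def warn_if_truncated(dtype, fun_name):
    """Hypothetical helper: warn if JAX would narrow the requested dtype."""
    requested = np.dtype(dtype)
    # With x64 disabled (the default), int64 canonicalizes to int32.
    effective = np.dtype(jax.dtypes.canonicalize_dtype(requested))
    if requested != effective:
        warnings.warn(
            f"{fun_name}: requested dtype {requested} is not available and "
            f"will be truncated to {effective}; enable jax_enable_x64 to "
            "use 64-bit types.",
            stacklevel=2,
        )
    return effective


def randint_checked(key, shape, minval, maxval, dtype=np.int32):
    """Hypothetical wrapper around jax.random.randint that warns first."""
    effective = warn_if_truncated(dtype, "randint_checked")
    return jr.randint(key, shape, minval, maxval, dtype=effective)


key = jr.PRNGKey(0)
# Under the default configuration this emits a warning instead of
# silently returning an int32 array.
a = randint_checked(key, (10,), 10, 20, dtype=np.int64)
```

With `jax_enable_x64` turned on, the requested and effective dtypes match, so the wrapper stays silent and returns an int64 array.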

It looks like this has been fixed.

jakevdp avatar Nov 07 '23 19:11 jakevdp