
jax.random: warn on unsupported dtypes

Open • jakevdp opened this issue 3 years ago • 1 comment

Fixes #12364

Explanation for the default changes: jnp.float_ and jnp.int_ are currently equal to float64 and int64, respectively, and are canonicalized upon use. We can't use them as default arguments here, because they trigger the unsupported-dtype warning when X64 is disabled. The Python scalar types int and float, on the other hand, do not trigger the warning and are canonicalized to the appropriate default type (32-bit when X64 is disabled, 64-bit when it is enabled), so they have the desired behavior.
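A minimal sketch of the behavior described above, assuming a JAX release in which jax.random.normal accepts a dtype argument and warns on unsupported dtypes, with jax_enable_x64 left off (the default). It passes jnp.float64 directly rather than jnp.float_, since the description says the two are equal; the helper sample_with_warnings is hypothetical and exists only for this illustration.

```python
import warnings

import jax
import jax.numpy as jnp

# Make the assumption explicit: 64-bit mode is off (also JAX's default).
jax.config.update("jax_enable_x64", False)

key = jax.random.PRNGKey(0)


def sample_with_warnings(dtype):
    """Hypothetical helper: draw three standard-normal samples in `dtype`.

    Returns the dtype of the result and the text of any warnings raised.
    """
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, even repeats
        x = jax.random.normal(key, (3,), dtype=dtype)
    return x.dtype, [str(w.message) for w in caught]


# An explicit 64-bit dtype is truncated to float32, and a warning is emitted.
dtype_64, warnings_64 = sample_with_warnings(jnp.float64)

# The Python scalar type `float` is canonicalized to the default float type
# (float32 here) without any warning.
dtype_py, warnings_py = sample_with_warnings(float)

print(dtype_64, len(warnings_64))
print(dtype_py, len(warnings_py))
```

With jax_enable_x64 turned on instead, both calls should return float64 and neither should warn, which is why the scalar types work as a default in either mode.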

jakevdp • Sep 14 '22 19:09

Is this consistent with our "remove x64 flag" plans? Or, e.g., should we just switch the defaults to be 32-bit?

mattjj • Sep 16 '22 17:09