jax
jax.random: warn on unsupported dtypes
Fixes #12364
Explanation for the default changes: jnp.float_ and jnp.int_ are currently equal to float64 and int64, respectively, and are canonicalized upon use. We can't use them as default arguments here, because they trigger the unsupported-dtype warning when x64 is disabled. The Python scalar types int and float, on the other hand, do not trigger the warning and are canonicalized to the appropriate default type, so they have the desired behavior.
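A minimal sketch of the canonicalization behavior described above. It forces the default configuration (jax_enable_x64 set to False) and uses the public jax.dtypes.canonicalize_dtype helper to show that the Python scalar types map to 32-bit defaults. It also shows that an explicit 64-bit dtype is canonicalized to the same result, which is why a caller passing float64 can only be warned rather than served. The exact warning text emitted by jax.random may differ between JAX versions, so it is not reproduced here.

```python
import jax
import jax.numpy as jnp
from jax import dtypes

# Assumption: default 32-bit mode. Set explicitly so the sketch is deterministic.
jax.config.update("jax_enable_x64", False)

# Python scalar types canonicalize to JAX's default precision.
print(dtypes.canonicalize_dtype(float))  # float32
print(dtypes.canonicalize_dtype(int))    # int32

# An explicit 64-bit dtype is also truncated to 32 bits when x64 is off;
# since the caller asked for 64 bits, this is the case jax.random flags.
print(dtypes.canonicalize_dtype(jnp.float64))  # float32

# Using the scalar type as the dtype yields the default float type.
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (3,), dtype=float)
print(x.dtype)  # float32
```

Because float and int resolve to whatever the current default precision is, they remain correct defaults whether x64 is enabled or not, without tripping the unsupported-dtype check.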
Is this consistent with our "remove x64 flag" plans? Or should we instead just switch the defaults to be 32-bit?