jax
Make DTypeLike runtime checkable
@patrick-kidger, as reported in patrick-kidger/jaxtyping#165, here is the PR to make SupportsDType runtime checkable so that DTypeLike can be used with runtime type checkers.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
One note: SupportsDType is not entirely accurate, because there is at least one type that matches its protocol and yet is not a valid argument to np.dtype:
>>> import numpy as np
>>> x = np.arange(4)
>>> isinstance(x, SupportsDType)  # assuming SupportsDType is made @runtime_checkable
True
>>> np.dtype(x)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[3], line 1
----> 1 np.dtype(x)
TypeError: Cannot construct a dtype from an array
If you want an accurate runtime check of whether a particular object is valid as a dtype, you'll have to go beyond protocols because (unfortunately) Python type annotations are not flexible enough to express the actual semantics of np.dtype.
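One way to get an accurate runtime answer, sketched here as an assumption rather than anything JAX provides, is to skip annotations entirely and ask NumPy: attempt the conversion and report whether it succeeded. `is_valid_dtype` is a hypothetical helper name.

```python
import numpy as np


def is_valid_dtype(obj) -> bool:
    """Hypothetical helper: True if np.dtype() accepts `obj`.

    Instead of describing valid inputs with a type annotation, this defers
    to NumPy itself by trying the conversion.
    """
    try:
        np.dtype(obj)
    except (TypeError, ValueError):
        return False
    return True


print(is_valid_dtype("float32"))      # True
print(is_valid_dtype(np.int32))       # True
print(is_valid_dtype(np.arange(4)))   # False: arrays are rejected
print(is_valid_dtype("not-a-dtype"))  # False
```

The trade-off is that this is a behavioral check, not a type: it cannot be used in an annotation, only called at runtime.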
Given that, do you think this change is still important?
I'd rather not use Protocols. What about being more explicit: remove SupportsDType, which is only used in DTypeLike, and use something like this instead:
from typing import Union
import numpy as np
import jax.numpy as jnp

NumberType = Union[jnp.float32, jnp.int32]  # to be completed with all JAX scalar number types
ScalarType = Union[jnp.bool_, NumberType]
DTypeLike = Union[
    str,  # like 'float32', 'int32'
    type[Union[bool, int, float, complex, np.bool_, np.number, ScalarType]],
    np.dtype,
]
It's unfortunate that, unlike NumPy, there is no superclass for all the JAX scalar types.
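A minimal sketch of how a runtime checker might evaluate the proposed explicit union, restricted to its Python and NumPy entries; `matches_explicit_dtype_like` is a hypothetical name, and the JAX scalar types are left out because, as noted above, they would have to be enumerated one by one. Unlike the protocol, it rejects arrays, but it accepts any string without validating it and rejects objects that np.dtype would accept through their dtype attribute.

```python
import numpy as np

# Python and NumPy scalar types from the proposed type[Union[...]] entry.
# JAX scalar types (jnp.float32, ...) are omitted in this sketch; they share
# no common superclass, so each would need to be listed explicitly.
_SCALAR_TYPES = (bool, int, float, complex, np.bool_, np.number)


def matches_explicit_dtype_like(obj) -> bool:
    """Hypothetical runtime check mirroring the proposed explicit DTypeLike."""
    if isinstance(obj, (str, np.dtype)):
        return True  # strings are not validated: "nonsense" passes too
    return isinstance(obj, type) and issubclass(obj, _SCALAR_TYPES)


print(matches_explicit_dtype_like(np.float32))         # True
print(matches_explicit_dtype_like(np.dtype("int32")))  # True
print(matches_explicit_dtype_like(np.arange(4)))       # False
```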
> I'd rather not use Protocols.
To come close to accurately describing the valid inputs of numpy.dtype, you need protocols. This is a NumPy thing, not a JAX thing (jnp.dtype is just an alias of np.dtype). np.dtype accepts anything with a dtype attribute, except NumPy arrays themselves. As far as I know, Python provides no way to annotate that API accurately with static types.
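The behavior described here can be checked directly. `HasDType` is a hypothetical class whose dtype attribute holds an actual np.dtype instance; an array also has a dtype attribute, yet np.dtype refuses it.

```python
import numpy as np


class HasDType:
    """Hypothetical user type exposing a NumPy dtype as an attribute."""

    dtype = np.dtype("float32")


# np.dtype() reads the `dtype` attribute of arbitrary objects...
converted = np.dtype(HasDType())
print(converted)  # float32

# ...but refuses NumPy arrays, even though they have a `dtype` attribute too.
arr = np.arange(4)
print(hasattr(arr, "dtype"))  # True
try:
    np.dtype(arr)
    array_rejected = False
except TypeError:
    array_rejected = True
print(array_rejected)  # True
```

A protocol can express "has a dtype attribute" but not "and is not an ndarray", which is why no annotation matches np.dtype exactly.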
> It's unfortunate that, unlike NumPy, there is no superclass for all the JAX scalar types.
Again, this is an issue with numpy.dtype, so I'm not sure what JAX could do differently to address it.
https://github.com/google/jax/issues/22144