Plum 2 and np.typing.NDArray
Hi @wesselb
I am testing out https://github.com/wesselb/plum/pull/73 because of its improved NumPy support. In the very first example there, you show how to dispatch based on both shape and dtype. In my use cases, I am only interested in dispatching based on dtype. Since np.typing.NDArray already supports that, I thought I could just use it, but that doesn't seem to work.
Below is a minimal example, in case you want to have a look before the Plum 2 release.
```python
from plum import dispatch, parametric
from typing import Any, Optional, Tuple, Union

import numpy as np
import numpy.typing


class NDArrayMeta(type):
    def __instancecheck__(self, x):
        if self.concrete:
            shape, dtype = self.type_parameter
        else:
            shape, dtype = None, None
        return (
            isinstance(x, np.ndarray)
            and (shape is None or x.shape == shape)
            and (dtype is None or x.dtype == dtype)
        )


@parametric
class NDArray(np.ndarray, metaclass=NDArrayMeta):
    @classmethod
    @dispatch
    def __init_type_parameter__(
        cls,
        shape: Optional[Tuple[int, ...]],
        dtype: Optional[Any],
    ):
        """Validate the type parameter."""
        return shape, dtype

    @classmethod
    @dispatch
    def __le_type_parameter__(
        cls,
        left: Tuple[Optional[Tuple[int, ...]], Optional[Any]],
        right: Tuple[Optional[Tuple[int, ...]], Optional[Any]],
    ):
        """Define an order on type parameters. That is, check whether
        `left <= right` or not."""
        shape_left, dtype_left = left
        shape_right, dtype_right = right
        return (
            (shape_right is None or shape_left == shape_right)
            and (dtype_right is None or dtype_left == dtype_right)
        )


@dispatch
def f(x: NDArray[None, np.int32]):
    print("An int array!")


@dispatch
def f(x: NDArray[None, np.float64]):
    print("A float array!")


print("BEGIN f")
f(np.ones((3, 3), np.int32))
f(np.ones((2, 2), np.float64))
print("END f")


@dispatch
def g(x: np.typing.NDArray[np.int32]):
    print("An int array!")


@dispatch
def g(x: np.typing.NDArray[np.float64]):
    print("A float array!")


print("BEGIN g")
g(np.ones((3, 3), np.int32))
g(np.ones((2, 2), np.float64))
print("END g")
```
And the corresponding output:
```
BEGIN f
An int array!
A float array!
END f
BEGIN g
/usr/local/lib/python3.11/dist-packages/plum/signature.py:203: UserWarning: Could not resolve the type hint of `numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]`. I have ended the resolution here to not make your code break, but some types might not be working correctly. Please open an issue at https://github.com/wesselb/plum.
  annotation = resolve_type_hint(p.annotation)
/usr/local/lib/python3.11/dist-packages/plum/type.py:261: UserWarning: Could not resolve the type hint of `numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]`. I have ended the resolution here to not make your code break, but some types might not be working correctly. Please open an issue at https://github.com/wesselb/plum.
  return _is_faithful(resolve_type_hint(x))
/usr/local/lib/python3.11/dist-packages/plum/type.py:261: UserWarning: Could not determine whether `numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]` is faithful or not. I have concluded that the type is not faithful, so your code might run with subpar performance. Please open an issue at https://github.com/wesselb/plum.
  return _is_faithful(resolve_type_hint(x))
/usr/local/lib/python3.11/dist-packages/plum/signature.py:203: UserWarning: Could not resolve the type hint of `numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]`. I have ended the resolution here to not make your code break, but some types might not be working correctly. Please open an issue at https://github.com/wesselb/plum.
  annotation = resolve_type_hint(p.annotation)
/usr/local/lib/python3.11/dist-packages/plum/type.py:261: UserWarning: Could not resolve the type hint of `numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]`. I have ended the resolution here to not make your code break, but some types might not be working correctly. Please open an issue at https://github.com/wesselb/plum.
  return _is_faithful(resolve_type_hint(x))
/usr/local/lib/python3.11/dist-packages/plum/type.py:261: UserWarning: Could not determine whether `numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]` is faithful or not. I have concluded that the type is not faithful, so your code might run with subpar performance. Please open an issue at https://github.com/wesselb/plum.
  return _is_faithful(resolve_type_hint(x))
Traceback (most recent call last):
  File "/tmp/p.py", line 72, in <module>
    g(np.ones((3, 3), np.int32))
  File "/usr/local/lib/python3.11/dist-packages/plum/function.py", line 342, in __call__
    self._resolve_pending_registrations()
  File "/usr/local/lib/python3.11/dist-packages/plum/function.py", line 237, in _resolve_pending_registrations
    self._resolver.register(subsignature)
  File "/usr/local/lib/python3.11/dist-packages/plum/resolver.py", line 58, in register
    existing = [s == signature for s in self.signatures]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/plum/resolver.py", line 58, in <listcomp>
    existing = [s == signature for s in self.signatures]
                ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/plum/util.py", line 132, in __eq__
    return self <= other <= self
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/plum/signature.py", line 132, in __le__
    [TypeHint(x) <= TypeHint(y) for x, y in zip(self_types, other_types)]
  File "/usr/local/lib/python3.11/dist-packages/plum/signature.py", line 132, in <listcomp>
    [TypeHint(x) <= TypeHint(y) for x, y in zip(self_types, other_types)]
     ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/beartype/door/_doormeta.py", line 148, in __call__
    _HINT_KEY_TO_WRAPPER.cache_or_get_cached_func_return_passed_arg(
  File "/usr/local/lib/python3.11/dist-packages/beartype/_util/cache/map/utilmapbig.py", line 231, in cache_or_get_cached_func_return_passed_arg
    value = value_factory(arg)
            ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/beartype/door/_doormeta.py", line 220, in _make_wrapper
    wrapper_subclass = get_typehint_subclass(hint)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/beartype/door/_doordata.py", line 108, in get_typehint_subclass
    raise BeartypeDoorNonpepException(
beartype.roar.BeartypeDoorNonpepException: Type hint numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]] invalid (i.e., either PEP-noncompliant or PEP-compliant but currently unsupported by "beartype.door.TypeHint").
```
Thanks!
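The `__le_type_parameter__` method in the example above defines a partial order on `(shape, dtype)` pairs in which `None` on the right-hand side acts as a wildcard, and the traceback shows that plum derives equality from such an order as `self <= other <= self`. A plain-Python sketch of that logic, assuming nothing beyond what the example shows: `le_param`, `eq_param` and the sample parameters are hypothetical, and dtypes are stand-in strings so it runs without NumPy or plum.

```python
from typing import Any, Optional, Tuple

# (shape, dtype); None means "unconstrained", as in NDArray[None, np.int32].
Param = Tuple[Optional[Tuple[int, ...]], Optional[Any]]


def le_param(left: Param, right: Param) -> bool:
    """Same logic as __le_type_parameter__: is `left` at least as specific as `right`?"""
    shape_left, dtype_left = left
    shape_right, dtype_right = right
    return (shape_right is None or shape_left == shape_right) and (
        dtype_right is None or dtype_left == dtype_right
    )


def eq_param(left: Param, right: Param) -> bool:
    """Equality derived from the order, mirroring plum's `self <= other <= self`."""
    return le_param(left, right) and le_param(right, left)


# Hypothetical sample parameters (dtypes as strings instead of NumPy scalar types).
int_any_shape = (None, "int32")
float_any_shape = (None, "float64")
int_3x3 = ((3, 3), "int32")
```

Because the `int32` and `float64` parameters are not ordered either way, `eq_param` keeps them distinct, so the two `f` methods are registered as different signatures.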
Hey @francesco-ballarin,
Thanks for opening an issue about this. :) You're right that numpy.typing currently won't work. The problem is that numpy.typing types unfortunately can't be used as runtime types by themselves:
```python
>>> isinstance(1, npt.NDArray[int])
TypeError: isinstance() argument 2 cannot be a parameterized generic
```
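The error is not specific to NumPy: `isinstance()` rejects any parameterized generic. A minimal stdlib sketch, with `list[int]` standing in for `npt.NDArray[...]` (Python 3.9+) and a hypothetical `matches` helper, shows that the origin and arguments are still recoverable with `typing.get_origin` and `typing.get_args`, which is the kind of additional support a dispatcher has to provide.

```python
import typing

# A parameterized generic from the standard library, standing in for
# npt.NDArray[np.int32], which is itself np.ndarray[Any, np.dtype[np.int32]].
hint = list[int]

# isinstance() refuses parameterized generics outright.
try:
    isinstance([1, 2], hint)
except TypeError as err:
    error = str(err)
else:
    error = None

# The information a dispatcher needs is still recoverable from the hint.
origin = typing.get_origin(hint)  # list
args = typing.get_args(hint)  # (int,)


def matches(value, hint) -> bool:
    """Hypothetical hand-rolled check: test the origin, then each element's type."""
    (elem_type,) = typing.get_args(hint)
    return isinstance(value, typing.get_origin(hint)) and all(
        isinstance(v, elem_type) for v in value
    )
```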
They are type hints, like the objects from typing: they won't work on their own, but need additional support. @beartype seems to come very close, but unfortunately doesn't quite get it right:
```python
>>> from beartype.door import is_bearable, TypeHint
>>> is_bearable(np.ones(1, int), npt.NDArray[int])  # Nice!
True
>>> TypeHint(npt.NDArray[int])  # Nice!
TypeHint(numpy.ndarray[typing.Any, numpy.dtype[int]])
>>> TypeHint(npt.NDArray[int]) == TypeHint(npt.NDArray[float])  # :(
True
```
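For comparison, a dtype-aware comparison of these hints can be hand-rolled on top of `typing.get_origin`/`get_args`. This is a minimal sketch, not plum's or beartype's API: `dtype_of`, `is_instance`, `hint_le` and `hint_eq` are hypothetical helpers; the sketch ignores the shape parameter and assumes Python 3.9+ with NumPy 1.22+, where `npt.NDArray[...]` is a plain generic alias.

```python
from typing import get_args, get_origin

import numpy as np
import numpy.typing as npt


def dtype_of(hint):
    """Return the scalar type inside an npt.NDArray[...] hint.

    npt.NDArray[np.int32] expands to np.ndarray[<shape>, np.dtype[np.int32]],
    so look for the np.dtype[...] argument instead of assuming its position.
    """
    for arg in get_args(hint):
        if get_origin(arg) is np.dtype:
            (scalar,) = get_args(arg)
            return scalar
    raise TypeError(f"not an npt.NDArray hint: {hint!r}")


def is_instance(x, hint) -> bool:
    """isinstance()-like check for NDArray hints that also inspects the dtype."""
    return isinstance(x, get_origin(hint)) and np.issubdtype(x.dtype, dtype_of(hint))


def hint_le(left, right) -> bool:
    """Order on NDArray hints: `left <= right` if its scalar type is a subtype."""
    return get_origin(left) is get_origin(right) and issubclass(
        dtype_of(left), dtype_of(right)
    )


def hint_eq(left, right) -> bool:
    """Equality derived from the order, in the style of `self <= other <= self`."""
    return hint_le(left, right) and hint_le(right, left)
```

With this order, `hint_eq(npt.NDArray[np.int32], npt.NDArray[np.float64])` is False, and `npt.NDArray[np.int32]` sits below `npt.NDArray[np.integer]`, which is the kind of ordering a dispatcher needs to pick the most specific method.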
You might be getting a beartype.roar.BeartypeDoorNonpepException while I am not because we are on different Python versions: your traceback shows Python 3.11, whereas I ran the above using Python 3.9. EDIT: This seems to be the case; see the issue linked below.
I've opened an issue on @beartype to see what @leycec thinks about this.