
Do not trigger type_of on return type hint

Open mofeing opened this issue 3 years ago • 6 comments

I am working with nested NumPy arrays (arrays of arrays) and I want to dispatch based on their nesting level.

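from plum import dispatch

# NestedArray is the parametric type defined in the next snippet.

@dispatch
def something(a: NestedArray[0]):  # a plain, non-nested array
    pass

@dispatch
def something(a: NestedArray[1]):  # an array whose elements are arrays
    pass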

I'm using the following hook for that:

import numpy as np
from plum import parametric, type_of

@parametric(runtime_type_of=True)
class NestedArray(np.ndarray):
    """A type for recursive numpy arrays (array of arrays) where the type parameter specifies the nesting level."""
    pass


@type_of.dispatch
def type_of(x: np.ndarray):
    # Descend through the first element until it is no longer an array.
    level = 0
    while isinstance(x.flat[0], np.ndarray):
        level += 1
        x = x.flat[0]

    return NestedArray[level]
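The nesting-level loop in the hook can be checked in isolation with plain NumPy, without Plum. The sketch below pulls it into a hypothetical helper, `nesting_level`, and builds a level-1 array explicitly with `np.empty(..., dtype=object)`: passing equally shaped arrays to `np.array(..., dtype=object)` would instead produce a 2-D object array of scalars, which has level 0. Like the hook, it only inspects the first element at each level, so it assumes uniform nesting and no empty arrays along the way.

```python
import numpy as np


def nesting_level(x: np.ndarray) -> int:
    """Hypothetical helper: the same loop as the type_of hook above."""
    level = 0
    # Only the first element is inspected; an empty array would raise IndexError.
    while isinstance(x.flat[0], np.ndarray):
        level += 1
        x = x.flat[0]
    return level


# A level-1 array: a 1-D object array whose elements are float arrays.
outer = np.empty(2, dtype=object)
outer[0] = np.zeros(3)
outer[1] = np.ones(3)

print(nesting_level(np.zeros(3)))  # 0
print(nesting_level(outer))        # 1
```

Wrapping `outer` in one more object array the same way yields level 2.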

It works like a charm. But I run into a really annoying side effect whenever a dispatched function has numpy.ndarray as its return type hint:

import numpy as np

@dispatch
def to_numpy(a) -> np.ndarray:
    ...
    return np.zeros(4)  # returns an ndarray, e.g. zeros

This code crashes because Plum runs type_of on the returned object, which gives a Type, and then compares that Type against the return type hint of to_numpy, which fails.

...
  File "plum/function.py", line 537, in plum.function.Function.__call__
  File "plum/function.py", line 173, in plum.function._convert
  File "plum/function.py", line 537, in plum.function.Function.__call__
  File "/home/mofeing/Develop/rosnet/.venv/lib/python3.8/site-packages/plum/promotion.py", line 32, in convert
    return _convert.invoke(type_of(obj), type_to)(obj, type_to)
  File "plum/function.py", line 552, in plum.function.Function.invoke.wrapped_method
  File "/home/mofeing/Develop/rosnet/.venv/lib/python3.8/site-packages/plum/promotion.py", line 43, in _convert
    if type_from <= type_to:
  File "/home/mofeing/Develop/rosnet/.venv/lib/python3.8/site-packages/plum/util.py", line 43, in __ge__
    return other.__le__(self)
TypeError: descriptor '__le__' requires a 'numpy.ndarray' object but received a 'Type'

If I remove the return type hint, it works.

I guess this is a bug because there is no way (yet) to dispatch based on return type, and also because the function should be returning the object, not the hooked type.

mofeing, Feb 08 '22 17:02

Hey @mofeing! Thanks for opening an issue. What appears to fix the issue is returning the type wrapped in a representation used internally in the package. I believe that the following small change should fix the error:

import numpy as np
from plum import parametric, type_of, ptype

@parametric(runtime_type_of=True)
class NestedArray(np.ndarray):
    """A type for recursive numpy arrays (array of arrays) where the type parameter specifies the nesting level."""
    pass


@type_of.dispatch
def type_of(x: np.ndarray):
    level = 0
    while isinstance(x.flat[0], np.ndarray):
        level += 1
        x = x.flat[0]

    # Wrap the result in Plum's internal type representation.
    return ptype(NestedArray[level])

Would you be able to check if this resolves all side effects?

The conversion with ptype should be done automatically, but that's clearly not the case, so this is indeed a bug. Thanks for catching this! I'm going to look into it more closely.

wesselb, Feb 08 '22 17:02

"I guess this is a bug because there is no way (yet) to dispatch based on return type"

You can't dispatch based on the return type. What this actually does is convert the function's output to the type you specified (np.ndarray) and raise an error if it can't.

This mimics Julia's syntax. However, I find it rather annoying, because 99% of the time I don't want this feature; I just want a return type annotation that has no effect...
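The return-annotation behaviour described above (convert the result to the annotated type, raise if that fails) can be illustrated without Plum. The sketch below uses a hypothetical decorator, `convert_return`; it is not Plum's implementation, only the same idea in plain Python. It assumes the annotation is a plain class that can be called to convert a value (such as `float`), not a `typing` construct.

```python
import functools
import inspect


def convert_return(func):
    """Hypothetical decorator: convert the result to the annotated return type."""
    target = inspect.signature(func).return_annotation

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        # No annotation, or already the right type: return unchanged.
        if target is inspect.Signature.empty or isinstance(result, target):
            return result
        try:
            return target(result)  # attempt a conversion, Julia-style
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"cannot convert {type(result).__name__} to {target.__name__}"
            ) from e

    return wrapper


@convert_return
def as_float(x) -> float:
    return x  # may return an int; the decorator converts it


print(as_float(3))  # 3.0
```

Calling `as_float("abc")` raises `TypeError`, because `float("abc")` fails. This is why an annotation that is meant to be purely informational still has a runtime effect under these semantics.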

PhilipVinc, Feb 08 '22 17:02

@wesselb maybe a better thing to do would be to check when type_of has a new rule and validate the return type? Doing this at runtime would further slow down parametric dispatch, which is already quite slow...

PhilipVinc, Feb 08 '22 17:02

"maybe a better thing to do would be to check when type_of has a new rule and validate the return type?"

Do you mean that the user would be required to specify a return type for type_of, which must be of the right type? The solution I had in mind is to define a new function ptype_of, which simply applies type_of and then ptype, and to use ptype_of in the parts of the code that implicitly assume they receive ptypes rather than types.
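The ptype_of idea described above can be sketched without Plum. Everything below is hypothetical. `Type` stands in for Plum's internal representation, and `ptype` here only wraps raw classes and passes wrapped ones through. `NonEmptyList` plays the role of `NestedArray[level]`. A user hook may return a raw class, and `ptype_of` normalizes it, so internal comparisons always see two `Type` objects.

```python
class Type:
    """Hypothetical stand-in for the internal type representation ptype returns."""

    def __init__(self, cls: type):
        self.cls = cls

    def __le__(self, other):
        # Subtype check, defined only between two wrapped types.
        if not isinstance(other, Type):
            return NotImplemented
        return issubclass(self.cls, other.cls)

    def __repr__(self):
        return f"Type({self.cls.__name__})"


def ptype(t):
    """Normalize: wrap raw classes; pass already-wrapped types through."""
    return t if isinstance(t, Type) else Type(t)


class NonEmptyList(list):
    """Hypothetical refined type, analogous to NestedArray[level]."""


# A user hook that returns raw classes, like the original type_of hook.
_hooks = {list: lambda x: NonEmptyList if x else list}


def type_of(obj):
    hook = _hooks.get(type(obj))
    return hook(obj) if hook is not None else type(obj)


def ptype_of(obj) -> Type:
    """type_of followed by ptype: internal code always receives a Type."""
    return ptype(type_of(obj))


print(ptype_of([1]) <= ptype(list))  # True
print(ptype_of(3) <= ptype(list))    # False
```

Comparing a `Type` directly with a raw class such as `list` raises `TypeError` in this sketch. Normalizing both sides with `ptype` before comparing avoids that kind of mixed comparison.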

wesselb, Feb 08 '22 18:02

@wesselb It works now with the call to ptype! Thanks.

mofeing, Feb 09 '22 09:02

@mofeing I’m glad to hear that! I’m leaving this issue open until the bug has been fixed.

wesselb, Feb 09 '22 09:02