Pyright incorrectly binds type variables in higher-order functions
For the following code, pyright correctly binds the type variables in the nested higher-order function calls, and the revealed type of map2 is correct.
from collections.abc import Callable
def curry[First, *Rest, Result](function: Callable[[First, *Rest], Result]) -> Callable[[*Rest], Callable[[First], Result]]:
    return lambda *rest: lambda first: function(first, *rest)
@curry
@curry
def map1[From, To, Arg](value: Arg, first: Callable[[Arg], From], second: Callable[[From], To]) -> To:
    return second(first(value))
@curry
@curry
def map1_copy[From, To, Arg](value: Arg, first: Callable[[Arg], From], second: Callable[[From], To]) -> To:
    return second(first(value))
# While `map1` is used to map the result of a function of one argument like this:
# (From -> To) -> (Arg1 -> From) -> (Arg1 -> To)
# `map2` is meant to map the result of a function of two arguments like this:
# (From -> To) -> (Arg1 -> Arg2 -> From) -> (Arg1 -> Arg2 -> To)
map2 = map1(map1)(map1)
reveal_type(map2) # ((From(2)@map1) -> To(2)@map1) -> ((((Arg(1)@map1) -> ((Arg(2)@map1) -> From(2)@map1))) -> ((Arg(1)@map1) -> ((Arg(2)@map1) -> To(2)@map1)))
# Negate the sum of two ints
reveal_type(map2(int.__neg__)(curry(int.__add__))) # (int) -> ((int) -> int)
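As a runtime sanity check, here is a minimal untyped sketch of the same functions. The annotations are dropped on purpose, so it says nothing about what pyright infers; it only shows what the functions compute. The names map2_from_copy, negated_sum, negated_sum_from_copy and sum_to_str are introduced here for illustration only.

```python
# Untyped runtime sketch: it only shows what the functions compute,
# not what any type checker infers.
def curry(function):
    return lambda *rest: lambda first: function(first, *rest)

@curry
@curry
def map1(value, first, second):
    return second(first(value))

@curry
@curry
def map1_copy(value, first, second):
    return second(first(value))

# Both definitions build the same function at runtime.
map2 = map1(map1)(map1)
map2_from_copy = map1(map1_copy)(map1_copy)  # hypothetical name

# Negate the sum of two ints
negated_sum = map2(int.__neg__)(curry(int.__add__))
negated_sum_from_copy = map2_from_copy(int.__neg__)(curry(int.__add__))

# Convert the sum of two ints into a string
sum_to_str = map2(str)(curry(int.__add__))

print(negated_sum(2)(3))            # -5
print(negated_sum_from_copy(2)(3))  # -5
print(sum_to_str(2)(3))             # 5
```

Since map1 and map1_copy behave identically at runtime, any difference in the revealed types comes from the type checker alone.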
However, if I change the definition of map2 to use an exact copy of map1, pyright can no longer bind the type variables, and the revealed type of map2 is wrong.
map2 = map1(map1_copy)(map1_copy)
reveal_type(map2) # ((From(1)@map1_copy) -> To(1)@map1_copy) -> ((((((From(1)@map1_copy) -> To(1)@map1_copy)) -> ((Arg(1)@map1_copy) -> From(1)@map1_copy))) -> ((((From(1)@map1_copy) -> To(1)@map1_copy)) -> ((Arg(1)@map1_copy) -> To(1)@map1_copy)))
reveal_type(map2(int.__neg__)(curry(int.__add__))) # ((int) -> int) -> ((int) -> int)
As a consequence, this also results in a false positive error.
Argument of type "(int) -> ((int) -> int)" cannot be assigned to parameter of type "((int) -> int) -> ((Arg(1)@map1_copy) -> int)"
  Type "(int) -> ((int) -> int)" is incompatible with type "((int) -> int) -> ((int) -> int)"
    Parameter 1: type "(int) -> int" is incompatible with type "int"
      "function" is incompatible with "int"
For comparison, mypy's behavior doesn't change depending on whether I use map1 or map1_copy to define map2, which is the expected behavior.
I'll note that even with the working definition of map2, if I try to define a function that converts the sum of two ints into a string, pyright gives me an error.
map2 = map1(map1)(map1)
sum_to_str = map2(str)(curry(int.__add__))
Argument of type "(int) -> ((int) -> int)" cannot be assigned to parameter of type "(Arg(1)@map1) -> Overload[(Arg(2)@map1) -> object, (Arg(2)@map1) -> ReadableBuffer]"
  Type "(int) -> ((int) -> int)" is incompatible with type "(int) -> Overload[(int) -> object, (int) -> ReadableBuffer]"
    Function return type "(int) -> int" is incompatible with type "Overload[(int) -> object, (int) -> ReadableBuffer]"
      One or more overloads of "" is not assignable
        Type "(int) -> int" is incompatible with type "(int) -> ReadableBuffer"
          Function return type "int" is incompatible with type "ReadableBuffer"
I'm not sure if this is caused by the same issue or not.
EDIT:
I've found a simpler example:
from collections.abc import Callable
def f[A, B](_: Callable[[A], B]) -> Callable[[Callable[[], A]], B]: ...
reveal_type(f(str)) # (Overload[() -> object, () -> Buffer]) -> str
f(str)(lambda: 1) # error here
Argument of type "() -> Literal[1]" cannot be assigned to parameter of type "Overload[() -> object, () -> ReadableBuffer]"
  One or more overloads of "" is not assignable
    Type "() -> Literal[1]" is incompatible with type "() -> ReadableBuffer"
      Function return type "Literal[1]" is incompatible with type "ReadableBuffer"
        "Literal[1]" is incompatible with protocol "Buffer"
          "__buffer__" is not present
This has more to do with overload handling. @erictraut, should I file a separate issue for this one?
This issue (the one at the top of this thread) was caused by the same bug as #8852. This will be addressed in the next release.
The other "simpler example" is unrelated. On first inspection, it doesn't appear to be a bug. The str constructor is overloaded, and you're attempting to assign an incompatible callable (a lambda) to this overloaded type. So I think pyright is correct in generating an error here.
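One possible way to sidestep the overload, sketched here under a few assumptions: wrapping str in a function with a single, non-overloaded signature should leave the type variable A with only one candidate (object), so the lambda becomes assignable. The wrapper to_str and the body of f are hypothetical (the original snippet only declares f's signature), and TypeVar is used instead of the PEP 695 syntax only so the sketch also runs on interpreters older than 3.12. This has not been checked against a specific pyright version.

```python
from collections.abc import Callable
from typing import TypeVar

A = TypeVar("A")
B = TypeVar("B")

# Hypothetical body for `f`; the original example only gives the signature.
def f(function: Callable[[A], B]) -> Callable[[Callable[[], A]], B]:
    return lambda thunk: function(thunk())

# Hypothetical wrapper: one plain signature instead of str's overloaded constructor.
def to_str(value: object) -> str:
    return str(value)

# A is solved as `object`, so a zero-argument callable returning an int fits.
result = f(to_str)(lambda: 1)
print(result)  # 1
```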
This is addressed in pyright 1.1.380.