Mypy errors on code that is a trivial syntactic transformation of code it accepts
Bug Report
Mypy reports a type error on the eta expansion of code that it accepts without error. See below for an example.
To Reproduce
from __future__ import annotations

from collections.abc import Callable
from typing import Generic, NamedTuple, TypeVar, TypeVarTuple

DType = TypeVar("DType")
DType2 = TypeVar("DType2")
Shape = TypeVarTuple("Shape")
Shape2 = TypeVarTuple("Shape2")
Dim1 = TypeVar("Dim1")
SeqLen = TypeVar("SeqLen")
BatchLen = TypeVar("BatchLen")


class ndarray(Generic[*Shape, DType]): ...  # noqa: N801


def vmap(
    fun: Callable[[ndarray[*Shape, DType]], ndarray[*Shape2, DType2]],
) -> Callable[[ndarray[Dim1, *Shape, DType]], ndarray[Dim1, *Shape2, DType2]]:
    raise NotImplementedError


def fn(tkns: ndarray[BatchLen, SeqLen, float]):
    _ = vmap(call)(tkns)
    _ = vmap(lambda x: call(x))(tkns)


def call(input_: ndarray[SeqLen, float]) -> ndarray[SeqLen, float]:
    raise NotImplementedError
Expected Behavior
Eta expansion (going from something like vmap(fn) to vmap(lambda x: fn(x))) should always be behavior-preserving and also type-preserving.
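As a plain runtime illustration of what eta expansion means, here is a minimal sketch. The names `apply` and `double` are hypothetical and unrelated to the reproduction above. Wrapping a function `f` as `lambda x: f(x)` yields a callable that behaves identically when handed to a higher-order function. This sketch only demonstrates the runtime equivalence with simple monomorphic types; it does not exercise the generic, TypeVarTuple-based case the report is about.

```python
from collections.abc import Callable


def apply(f: Callable[[int], int], x: int) -> int:
    # Hypothetical stand-in for a higher-order function such as vmap: it only calls f.
    return f(x)


def double(n: int) -> int:
    return 2 * n


direct = apply(double, 21)                 # pass the function itself
expanded = apply(lambda n: double(n), 21)  # pass its eta expansion
print(direct, expanded)
```

Both calls compute the same value, which is the property the report expects the type checker to respect as well.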
Actual Behavior
Instead, Mypy accepts the first call but rejects the second:
main.py:26: error: Argument 1 to "call" has incompatible type "ndarray[*Shape, DType]"; expected "ndarray[Never, float]" [arg-type]
Your Environment
This uses the default settings on https://mypy-play.net/?mypy=latest&python=3.12 with Mypy 1.11.2.