mypy type checking seems to break in strict mode -- a mypy bug?

Status: Open • bluenote10 opened this issue 3 years ago • 2 comments

Following up on https://github.com/patrick-kidger/torchtyping/issues/41, I'm trying the same things here. However, I'm not having much success with mypy so far. Am I doing something wrong?

import torch
from jaxtyping import Float

dim1 = "dim1"


# Expected to work but fails with:
# error: Returning Any from function declared to return "Tensor"
def simple_test_a(x: Float[torch.Tensor, "dim1"]) -> torch.Tensor:
    return x


# Expected to work but fails with:
# error: Returning Any from function declared to return "float"
def simple_test_b(x: Float[torch.Tensor, "dim1"]) -> float:
    return x.item()


# Expected to error, but passes type checking
def simple_test_c(x: Float[torch.Tensor, "dim1"]) -> None:
    x.asdfasdfasdf()

VSCode (pyright) seems to do a little better, but apparently doesn't like the import:

[Screenshot: pyright flagging the jaxtyping import]

bluenote10, Dec 22 '22 15:12

I'm afraid I can't replicate your issue. Running mypy on:

import torch
from jaxtyping import Float

def simple_test_a(x: Float[torch.Tensor, "dim1"]) -> torch.Tensor:
    reveal_type(x)
    return x

def simple_test_b(x: Float[torch.Tensor, "dim1"]) -> float:
    reveal_type(x)
    return x.item()

def simple_test_c(x: Float[torch.Tensor, "dim1"]) -> None:
    reveal_type(x)
    x.asdfasdfasdf()

prints:

tmp.py:5: note: Revealed type is "torch._tensor.Tensor"
tmp.py:9: note: Revealed type is "torch._tensor.Tensor"
tmp.py:13: note: Revealed type is "torch._tensor.Tensor"
tmp.py:14: error: "Tensor" has no attribute "asdfasdfasdf"  [attr-defined]

This is with versions:

torch: 1.13.1
jaxtyping: 0.2.9
mypy: 0.991
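To collect the same version information for a report like this, a small sketch along these lines should work, assuming the packages are installed under the distribution names `torch`, `jaxtyping` and `mypy`. `collect_versions` is a hypothetical helper, not part of any of these packages; it only uses the standard library's `importlib.metadata`.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names to report (the three packages listed above).
PACKAGES = ("torch", "jaxtyping", "mypy")


def collect_versions(packages=PACKAGES):
    """Return {package: version string}, using "not installed" for missing ones."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    # Print in the same "name: version" layout used in this thread.
    for name, ver in collect_versions().items():
        print(f"{name}: {ver}")
```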

As for VSCode, this issue is due to a now-resolved bug in pyright: https://github.com/microsoft/pyright/issues/4287 . Try updating your pyright version.

patrick-kidger, Dec 22 '22 15:12

Interesting, it seems to be related to strict mode. I can replicate your output when I run mypy in non-strict mode:

test.py:6: note: Revealed type is "torch._tensor.Tensor"
test.py:11: note: Revealed type is "torch._tensor.Tensor"
test.py:16: note: Revealed type is "torch._tensor.Tensor"
test.py:17: error: "Tensor" has no attribute "asdfasdfasdf"  [attr-defined]

As soon as I add a mypy.ini containing

[mypy]
strict = True

the output becomes:

test.py:5: error: Name "dim1" is not defined  [name-defined]
test.py:6: note: Revealed type is "Any"
test.py:7: error: Returning Any from function declared to return "Tensor"  [no-any-return]
test.py:10: error: Name "dim1" is not defined  [name-defined]
test.py:11: note: Revealed type is "Any"
test.py:12: error: Returning Any from function declared to return "float"  [no-any-return]
test.py:15: error: Name "dim1" is not defined  [name-defined]
test.py:16: note: Revealed type is "Any"

Note that even the revealed types change.
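To make that difference explicit rather than comparing the two outputs by eye, mypy's output can be parsed into structured diagnostics. This is a minimal sketch, assuming mypy's default one-line output format (no `--show-column-numbers`, no `--pretty`); `parse_mypy_output` and `revealed_types` are hypothetical helper names.

```python
import re
from typing import List, NamedTuple, Optional

# Matches mypy's default one-line diagnostic format:
#   <file>:<line>: <severity>: <message>  [<error-code>]
# Summary lines such as "Found 5 errors in 1 file ..." do not match.
_DIAG_RE = re.compile(
    r"^(?P<file>[^:]+):(?P<line>\d+): (?P<severity>error|warning|note): "
    r"(?P<message>.*?)(?:\s+\[(?P<code>[a-z0-9-]+)\])?$"
)

_REVEAL_PREFIX = "Revealed type is "


class Diagnostic(NamedTuple):
    file: str
    line: int
    severity: str
    message: str
    code: Optional[str]


def parse_mypy_output(text: str) -> List[Diagnostic]:
    """Turn mypy's stdout into structured diagnostics, skipping other lines."""
    diags = []
    for raw in text.splitlines():
        match = _DIAG_RE.match(raw.strip())
        if match:
            diags.append(
                Diagnostic(
                    match["file"],
                    int(match["line"]),
                    match["severity"],
                    match["message"],
                    match["code"],
                )
            )
    return diags


def revealed_types(diags: List[Diagnostic]) -> List[str]:
    """Extract the types reported by reveal_type(), in source order."""
    return [
        d.message[len(_REVEAL_PREFIX):].strip('"')
        for d in diags
        if d.severity == "note" and d.message.startswith(_REVEAL_PREFIX)
    ]
```

Applied to the two outputs above, `revealed_types` yields three `torch._tensor.Tensor` entries for the non-strict run and three `Any` entries for the strict run.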

This is with exactly the same package versions.

Looks like a mypy bug?
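One way to narrow this down before filing a mypy report is to enable the options that `--strict` (the command-line equivalent of `strict = True` in `mypy.ini`) bundles, one at a time, and check which of them flips the revealed type to `Any`. A rough sketch, assuming `test.py` contains the `reveal_type` calls above and that the strict `mypy.ini` is moved out of the way (otherwise every run is strict). The flag list is a hand-picked subset, since the exact set differs between mypy releases (`mypy --help` lists it for your version), and `flags_that_reveal_any` is a hypothetical helper. It uses `mypy.api.run`, which returns `(stdout, stderr, exit_status)`; the runner is injectable so the loop can be exercised without mypy.

```python
from typing import Callable, List, Sequence, Tuple

# Signature of mypy.api.run: args -> (stdout, stderr, exit_status).
Runner = Callable[[List[str]], Tuple[str, str, int]]

# A subset of the individual options that --strict switches on.
STRICT_SUBFLAGS = (
    "--warn-unused-configs",
    "--disallow-any-generics",
    "--disallow-subclassing-any",
    "--disallow-untyped-calls",
    "--disallow-untyped-defs",
    "--disallow-incomplete-defs",
    "--check-untyped-defs",
    "--disallow-untyped-decorators",
    "--warn-redundant-casts",
    "--warn-unused-ignores",
    "--warn-return-any",
    "--no-implicit-reexport",
    "--strict-equality",
)

ANY_MARKER = 'Revealed type is "Any"'


def _mypy_runner(args: List[str]) -> Tuple[str, str, int]:
    from mypy import api  # imported lazily so the loop can be tested without mypy

    return api.run(args)


def flags_that_reveal_any(
    path: str,
    flags: Sequence[str] = STRICT_SUBFLAGS,
    run: Runner = _mypy_runner,
) -> List[str]:
    """Return the flags which, enabled on their own, make reveal_type() print Any."""
    culprits = []
    for flag in flags:
        stdout, _stderr, _status = run([flag, path])
        if ANY_MARKER in stdout:
            culprits.append(flag)
    return culprits


if __name__ == "__main__":
    print(flags_that_reveal_any("test.py"))
```

If no single flag reproduces the `Any`, the behaviour may depend on a combination of options, which this one-at-a-time loop will not catch.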

bluenote10, Dec 22 '22 15:12