torchtyping

mypy not compatible with any named axes?

Open • zplizzi opened this issue 2 years ago • 4 comments

When I specify a type like TensorType["batch_size", "num_channels", "x", "y"], mypy reports a separate error for each named axis, for example error: Name "batch_size" is not defined. Is this expected, or am I doing something wrong? This is with the most recent mypy release, 0.950.

zplizzi • May 21 '22

This is expected: mypy treats the string as a forward reference (a name to be looked up later) rather than as a literal string. Python's typing system can be a bit of a mess in edge cases like this.
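PEP 484 treats a string inside a type expression as a forward reference, and Python's own runtime follows the same rule. The following is a minimal sketch using only the standard library, with typing.Tuple standing in for TensorType (an assumption made so the example runs without torch or torchtyping). It shows the axis-name strings being stored as forward references and then failing to resolve:

```python
from typing import ForwardRef, Tuple, get_type_hints


def f(x: Tuple["batch_size", "num_channels"]) -> None:
    """Annotated with axis-name strings, as in TensorType[...]."""


# Each string argument is stored as a ForwardRef: a name to be looked up
# later, not a literal string value.
args = f.__annotations__["x"].__args__
print([a.__forward_arg__ for a in args if isinstance(a, ForwardRef)])

# Resolving the annotation looks the names up in the module namespace and
# fails, the runtime counterpart of mypy's name-defined error.
try:
    get_type_hints(f)
    error = None
except NameError as err:
    error = str(err)
print(error)
```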

One solution is to define objects whose names match these strings. Another is to use the appropriate annotations to have mypy ignore the error.
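Below is a minimal sketch of the first workaround, with typing.Tuple as a torch-free stand-in for TensorType (an assumption). Once a module-level name exists for each axis, the forward references resolve. Using empty classes as the hypothetical placeholders is also an assumption: a class is always a valid type, whereas mypy generally rejects a plain variable used in a type position.

```python
from typing import Tuple, get_type_hints


# Hypothetical placeholder classes, one per axis name. They carry no
# behaviour; they only give each forward-reference string a target.
class batch_size:
    pass


class num_channels:
    pass


def f(x: Tuple["batch_size", "num_channels"]) -> None:
    """Annotated with axis-name strings; every name now exists."""


# The strings now resolve to the placeholder classes above.
hints = get_type_hints(f)
print(hints["x"])
```

The cost is namespace clutter: every axis name becomes a class in the module, which is why this reads as a workaround rather than a fix.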

patrick-kidger • May 21 '22

Got it, no worries; I understand the constraints here. It might still help to update the mypy section of the documentation to explain this more clearly, though. When I read that mypy was "mostly" supported, I expected this core feature to work without hacks.

zplizzi • May 22 '22

The latest version of Pyright (1.1.262) has started to report similar errors ("batch_size" is not defined).

stvhuang • Jul 22 '22

> Another is to use the appropriate annotations to have mypy ignore the error.

How exactly is this supposed to work? Even with the following code:

from torchtyping import TensorType  # type: ignore


def batch_outer_product(
    x: TensorType[
        "batch",  # type: ignore
        "x_channels",  # type: ignore
    ],
    y: TensorType[
        "batch",  # type: ignore
        "y_channels",  # type: ignore
    ],
) -> TensorType[
    "batch",  # type: ignore
    "x_channels",  # type: ignore
    "y_channels",  # type: ignore
]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)

I'm getting:

test.py:6: error: Name "batch" is not defined  [name-defined]
test.py:6: error: Name "x_channels" is not defined  [name-defined]
test.py:7: error: Unused "type: ignore" comment
test.py:8: error: Unused "type: ignore" comment
test.py:10: error: Name "batch" is not defined  [name-defined]
test.py:10: error: Name "y_channels" is not defined  [name-defined]
test.py:11: error: Unused "type: ignore" comment
test.py:12: error: Unused "type: ignore" comment
test.py:14: error: Name "batch" is not defined  [name-defined]
test.py:14: error: Name "x_channels" is not defined  [name-defined]
test.py:14: error: Name "y_channels" is not defined  [name-defined]
test.py:15: error: Unused "type: ignore" comment
test.py:16: error: Unused "type: ignore" comment
test.py:17: error: Unused "type: ignore" comment
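Judging from the line numbers in this log, mypy attributes each name-defined error to the line just above the flagged strings, that is, the line where the annotation opens (x: TensorType[, y: TensorType[ and ) -> TensorType[). That would explain why every per-string comment is reported as unused. The sketch below shows the placement that follows from this reading, with typing.Dict as a torch-free stand-in for TensorType. The placement is inferred from the log and has not been verified against torchtyping itself.

```python
from typing import Dict


# The ignore comments sit on the lines where each annotation opens, since
# that is where the log above reports the unresolved axis names.
def pair(
    x: Dict[  # type: ignore[name-defined]
        "batch",
        "x_channels",
    ],
) -> Dict[  # type: ignore[name-defined]
    "batch",
    "x_channels",
]:
    return x
```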

I assume this line in the documentation is no longer valid, right?

> Additionally mypy has a bug which causes it [to] crash on any file using the str: int or str: ... notation, as in TensorType["batch": 10].

The underlying issue (https://github.com/python/mypy/issues/10266) has been closed.
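Some context on why this notation is unusual in an annotation: inside square brackets, "batch": 10 is slice syntax, not a key-value pair. The subscript therefore receives a slice object, which is not among the type expressions PEP 484 describes. Here is a minimal sketch with a hypothetical Probe class (not part of torchtyping) that simply returns whatever it is subscripted with:

```python
class Probe:
    """Hypothetical helper: returns the raw object placed inside [...]."""

    def __class_getitem__(cls, item):
        return item


# A single "name": size entry arrives as one slice object.
single = Probe["batch": 10]
# Several comma-separated entries arrive as a tuple.
multi = Probe["batch": 10, "channels"]
print(single)
print(multi)
```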

bluenote10 • Dec 22 '22