torchtyping

PyCharm shows an incorrect dtype when assigning torch.randn to a variable

Open jamesdsmith99 opened this issue 2 years ago • 5 comments

Hi,

I have recently started using this library, so I might be using it incorrectly, but linting seems to fail in PyCharm when assigning the result of torch.randn to a TensorType with a float dtype.

Here is an example:

import torch
from torchtyping import TensorType
Matrix = TensorType['h', 'w', float]
x: Matrix = torch.randn(5, 3)

The second line gets underlined with the following error:

Expected type 'TensorType[Any, Any, float]', got 'Tensor' instead

If I modify the second line to:

x: Matrix = torch.randn(5, 3).float()

The error goes away, but I would rather not do that, as one of the plus sides of this library is removing extra typing-related code from my main logic. Having to add an explicit .float() call defeats the purpose of this library, IMO.

From reading the docs, this should work: torch.randn returns a tensor with the default dtype, and a TensorType that specifies float should expect the default dtype.
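The claim about torch.randn can be checked directly at runtime. This is only a minimal sketch, assuming a stock PyTorch install where the default dtype has not been changed; the variable names are illustrative.

```python
import torch

# torch.randn without an explicit dtype uses the global default dtype.
x = torch.randn(5, 3)
default = torch.get_default_dtype()

# Compare the tensor's dtype against the default dtype.
matches_default = x.dtype == default
print(x.dtype, default, matches_default)
```

If this comparison holds, the dtype is correct at runtime, which suggests the warning comes from static analysis rather than from the tensor itself.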

jamesdsmith99 Oct 03 '21 14:10

Without having things set up in PyCharm myself it'll be a fair bit of work to diagnose this.

If you remove the float specifier and have only TensorType['h', 'w'], does that still produce an error? I'm just trying to gather some data on what raises an error and what doesn't. More broadly, if you can track down what's causing the issue, then I'd be happy to accept a PR.

patrick-kidger Oct 04 '21 12:10

If you remove the float specifier and have only TensorType['h', 'w'] does that still produce an error?

I checked it and yes, there is still an error. Doing torch.randn(5, 3).float() works only because float() is untyped, so PyCharm can't assume anything about the return type and doesn't emit any warnings.

It seems that torchtyping doesn't work with PyCharm's type checker at all: no matter what I do, there is always a warning when assigning a Tensor to anything with a TensorType type hint.

And I guess it's not surprising, because TensorType is a subclass of Tensor, so the checker complains when we try to assign an instance of the parent class to something expecting a subclass.
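The parent-versus-subclass point can be reproduced without PyTorch. The sketch below uses hypothetical stand-in classes (Tensor, Annotated2D, randn are illustrative names, not the real torch or torchtyping objects) to show the pattern a static checker objects to.

```python
class Tensor:
    """Hypothetical stand-in for torch.Tensor."""


class Annotated2D(Tensor):
    """Hypothetical stand-in for a TensorType[...] annotation (a subclass)."""


def randn() -> Tensor:
    # Declared to return the parent class, like a function returning Tensor.
    return Tensor()


# A static checker flags this assignment: the annotation asks for the
# subclass, but the declared return type is the parent class.
x: Annotated2D = randn()

# Plain annotations are not enforced at runtime, so the code still runs.
print(isinstance(x, Tensor))       # True
print(isinstance(x, Annotated2D))  # False
```

The code runs without error, so the complaint exists only in the static checker's view of the declared types.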

spietras Apr 08 '22 12:04

Hi, are there any updates?

fzyzcjy Dec 11 '22 05:12

Yes! I'd recommend trying jaxtyping. Despite the name, it actually works equally well for PyTorch.

In particular, it's designed to play much better with static type checkers.

patrick-kidger Dec 11 '22 05:12

@patrick-kidger Interesting, thanks for the quick reply! (Never thought "jax" typing would work for "pytorch" before ;) )

fzyzcjy Dec 11 '22 05:12