Clean up prototype transforms before migration
The transforms in torchvision.transforms are fully JIT scriptable. To achieve this, we needed a lot of helper methods as well as unnecessarily strict or plainly wrong annotations.
The revamped transforms in torchvision.prototype.transforms will no longer be JIT scriptable due to their ability to handle more complex inputs. Thus, we should clean up our code, removing everything that was only added to appease torch.jit.script. This should only happen as one of the last steps before the migration to the main area.
cc @vfdev-5 @datumbox @bjuncek
This comment only applies in case we decide to reinstate the high-level dispatchers. Leaving this here so it doesn't get lost.
Since torchscript is statically compiled, we need special handling whenever we rely on duck typing. The most notable examples are sequences, which mostly means tuples and lists.
I propose to use "correct" annotations for the low-level kernels:
- If the argument has a fixed number of elements, use Tuple. For example, an image_size is always two integers, and thus we should use Tuple[int, int] and not allow List[int]. Note that this is different in instances where the type actually makes a difference in behavior, e.g. using an int or a Tuple[int, int] with resize. Since torchscript supports Union, we should be able to do Union[int, Tuple[int, int]].[^1]
- If the argument has an unknown number of elements, use List. There are probably only a few cases that actually need this. Anything related to channels, e.g. the mean and std arguments of normalize or the fill argument, are the only ones coming to mind. On the flip side, arguments like padding: List[int] should probably be Union[int, Tuple[int, int], Tuple[int, int, int, int]].
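To make this concrete, here is a rough sketch of what the proposed annotations could look like. The kernel names and parameters below are illustrative stubs, not the actual torchvision kernels:

```python
from typing import List, Tuple, Union

import torch

# Illustrative stubs only; the real kernels live in torchvision and take more parameters.

def center_crop(image: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
    # image_size always has exactly two elements, so Tuple[int, int] is used
    # instead of the looser List[int].
    ...

def resize(image: torch.Tensor, size: Union[int, Tuple[int, int]]) -> torch.Tensor:
    # An int and a two-tuple genuinely behave differently here, so the Union
    # is part of the semantics rather than a JIT workaround.
    ...

def normalize(image: torch.Tensor, mean: List[float], std: List[float]) -> torch.Tensor:
    # The number of channels is not fixed, so List[float] is appropriate.
    ...
```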
The only users affected by this are JIT users, since everyone else either uses the higher-level abstractions or relies on duck typing in the low-level kernels, which still works. That should clean up our code quite a bit.
Somewhat more controversially, I would also remove the type checks from the low-level kernels. Something like

```python
if not isinstance(arg, (tuple, list)):
    raise ValueError
```

is very restrictive, since the kernel would work with any sequence. It should be isinstance(arg, collections.abc.Sequence), but that is not supported by torchscript. However, we can have checks like that in the high-level dispatchers, since they are not JIT scriptable in general.
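For illustration (this snippet is not from the torchvision code base), a range object is a perfectly good sequence but fails the strict check:

```python
import collections.abc

arg = range(2)  # a sequence, but neither a tuple nor a list

print(isinstance(arg, (tuple, list)))             # False: the strict check rejects it
print(isinstance(arg, collections.abc.Sequence))  # True: the duck-typed check accepts it
```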
[^1]: I'm aware of comments like https://github.com/pytorch/vision/blob/ae1d70713a542aac905cf403844a21b38b2be593/torchvision/transforms/functional.py#L1116-L1118, but it seems they come from a time when Union was not properly supported. I was able to remove this limitation with a minimal patch locally.
This might no longer be relevant if #6584 is successful. Keeping open in case it fails. We can close later.
Regardless of #6584, the transforms will not be JIT scriptable. Thus, cleaning them up is still relevant.
Good point, it's not only about the functionals, it's also about the Transform classes.
We discussed this offline and the consensus is that we can have a go at it for the constructors.
I wrestled with this over the last few days and the result is somewhat sobering. Imagine we currently have a transform like
```python
import collections.abc
from typing import *

import torch

from torchvision.prototype import transforms


class Transform(transforms.Transform):
    def __init__(self, range: List[float]):
        super().__init__()

        if not (
            isinstance(range, collections.abc.Sequence)
            and all(isinstance(item, float) for item in range)
            and len(range) == 2
            and range[0] < range[1]
        ):
            raise TypeError(f"`range` should be a sequence of two increasing float values, but got {range}")

        self._dist = torch.distributions.Uniform(range[0], range[1])
```
This works perfectly fine for mypy as well as in eager mode. Meaning, you can pass anything that ducks as a sequence of two floats.
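For example, continuing the snippet above (and assuming a torchvision build that still ships torchvision.prototype.transforms), all of the following pass the constructor checks in eager mode; FloatPair is a made-up class used purely for illustration:

```python
class FloatPair(collections.abc.Sequence):
    # Any object implementing the sequence protocol works, not just tuples and lists.
    def __getitem__(self, idx):
        return (0.0, 1.0)[idx]

    def __len__(self):
        return 2


Transform([0.0, 1.0])   # list
Transform((0.0, 1.0))   # tuple
Transform(FloatPair())  # custom sequence
```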
As discussed above, List[float] is a sub-optimal annotation here. List[T], Sequence[T], Tuple[T, ...], and so on imply an arbitrary number of elements. Since this annotation only exists out of necessity for JIT, in theory, we should be able to switch to Tuple[float, float], right?
Wrong :disappointed: Unfortunately, mypy and Python disagree on the definition of a sequence. While isinstance((1.0, 2.0), collections.abc.Sequence) evaluates to True, mypy evaluates it to False :boom: To be more precise, Tuple[float, float] is not treated as a sequence, while Tuple[float, ...] is. This is really awkward from a Python perspective, but in statically typed languages tuples are not iterable in general, so this is not completely bonkers.
Still, this makes it really hard to use the right annotations here while keeping the type checks. Switching the annotation to Tuple[float, float] above and running mypy yields
```
main.py:13: error: <nothing> has no attribute "__iter__" (not iterable)  [attr-defined]
            and all(isinstance(item, float) for item in range)
                ^
main.py:15: error: Value of type <nothing> is not indexable  [index]
            and range[0] < range[1]
                ^
main.py:18: error: Value of type <nothing> is not indexable  [index]
        self._dist = torch.distributions.Uniform(range[0], range[1])
                     ^
Found 3 errors in 1 file (checked 1 source file)
```
The consequence of mypy not treating Tuple[float, float] as a sequence is that, from its perspective, isinstance(range, collections.abc.Sequence) can never evaluate to True, and thus we always raise the TypeError. All other statements where range is used then error, since they are "unreachable".
I currently see three ways forward here:
1. Drop the isinstance(range, collections.abc.Sequence) check. The sequence protocol is defined by implementing `__getitem__` and `__len__`. Since we are accessing both later in the checks anyway, it is impossible for non-sequence objects to get past this guard. Still, the user would see the native error instead of our custom one. For example, imagine someone passes an iterable object with a known length:

   ```python
   class Foo:
       def __len__(self):
           return 2

       def __iter__(self):
           yield 0.0
           yield 1.0

   Transform(Foo())
   ```

   ```
   TypeError: 'Foo' object is not subscriptable
   ```

2. Use TypeGuards. By replacing the vanilla sequence check with

   ```python
   from typing_extensions import TypeGuard

   T = TypeVar("T")

   def _is_tuple_ducktype(obj: Sequence[T]) -> TypeGuard[Tuple[T, ...]]:
       return isinstance(obj, collections.abc.Sequence)
   ```

   mypy is happy again (see the sketch after this list for how this could look in the constructor). Still, this has some downsides:

   - typing.TypeGuard is only available for Python >= 3.10. So far we have an unpinned dependency https://github.com/pytorch/vision/blob/23d3f78aeea9329a8257e17b90c37f6f2016c171/setup.py#L61 but we will have to pin it to typing_extensions >= 4; python_version < '3.10'.
   - We'll make the code more complex just for mypy. Although I'm usually pretty enthusiastic when it comes to annotations, I feel in this particular instance it is a step in the wrong direction.

3. Leave everything as is. Since we are not PEP 561 compliant anyway, torchvision will not be analyzed by mypy in third-party projects. Meaning, the annotations are mostly for us.
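For completeness, here is a rough, untested sketch of how option 2 could look in the constructor from above. Only _is_tuple_ducktype is taken from the snippet in option 2; the transforms.Transform base class and the super().__init__() call are omitted for brevity:

```python
import collections.abc
from typing import Sequence, Tuple, TypeVar

import torch
from typing_extensions import TypeGuard

T = TypeVar("T")


def _is_tuple_ducktype(obj: Sequence[T]) -> TypeGuard[Tuple[T, ...]]:
    # At runtime this is the same duck-typed check as before; the TypeGuard
    # return type additionally tells mypy that a passing object can be treated
    # as a tuple.
    return isinstance(obj, collections.abc.Sequence)


class Transform:
    def __init__(self, range: Tuple[float, float]):
        if not _is_tuple_ducktype(range):
            raise TypeError(f"`range` should be a sequence of two increasing float values, but got {range}")

        # From here on mypy treats `range` as Tuple[float, ...], so the indexing
        # and iteration below no longer produce the errors shown earlier.
        if not (
            all(isinstance(item, float) for item in range)
            and len(range) == 2
            and range[0] < range[1]
        ):
            raise TypeError(f"`range` should be a sequence of two increasing float values, but got {range}")

        self._dist = torch.distributions.Uniform(range[0], range[1])
```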
Thoughts?
I would go with option 3. IMO the current typing annotations provide sufficient info without introducing complex solutions like option 2.