
TorchScript compatibility?

Open Linux-cpp-lisp opened this issue 3 years ago • 7 comments

Hi all,

This library looks very nice :)

Is TensorType compatible with the TorchScript compiler? As in, are the annotations transparently converted to torch.Tensor as far as torch.jit.script is concerned, allowing annotated modules/functions to be compiled? (I'm not worried about whether the type checking is applied in TorchScript, just whether an annotated program that gets shape-checked in Python can be compiled down to TorchScript.)

Thanks!

Linux-cpp-lisp avatar Apr 09 '21 18:04 Linux-cpp-lisp

So I've been playing with this for a bit and unfortunately can't get it to work.

If you or someone else does manage to get this working, then I'd be happy to accept a PR on it.

For posterity:

  • TensorType does not currently inherit from torch.Tensor. This means that @torch.jit.script def f(x: TensorType) results in TensorType trying to be compiled, which fails.
  • This happens even when adding @torch.jit.ignore in various places. I think ignoring things only really works for free functions or methods of subclasses of torch.nn.Module.
  • Changing TensorType to inherit from torch.Tensor allows @torch.jit.script def f(x: TensorType), but @torch.jit.script def f(x: TensorType["b"]) still breaks, with error message Unknown type constructor TensorType from the TorchScript compiler.
  • Nothing I tried managed to fix that, and indeed a little googling suggests that it might be impossible, as type constructors are apparently parsed as strings: https://github.com/pytorch/pytorch/issues/29094 (And indeed I tried sneaky things like inheriting class TensorType(typing.List), without success.) My impression is that the only parameterised types admitted as annotations are the standard built-in ones like List.

patrick-kidger avatar Apr 09 '21 22:04 patrick-kidger

Hi @patrick-kidger, thanks for the quick answer! This level of arcane tinkering with TorchScript definitely sounds familiar to me... :grin:

The issue you link in the third bullet does make it look like there is nothing that can be done here until PyTorch resolves the underlying incompatibility with Python. (If I'm understanding this right you couldn't even do Annotated[torch.Tensor, something_else] since it wouldn't be parsable as a string, even though Python people worked hard to make Annotated backwards compatible.) Hopefully the PyTorch people are going to start using Python inspection for this like they said in the linked issue.
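(For what it's worth, plain Python tooling does treat Annotated transparently — the metadata is stripped unless a tool explicitly opts in — so the sticking point really is TorchScript's own source-string parsing, not the Python side. A minimal stdlib sketch, using a made-up "batch" string as the metadata:)

```python
from typing import Annotated, get_type_hints

# Hypothetical function: the "batch" string is arbitrary metadata.
def f(x: Annotated[int, "batch"]) -> int:
    return x + 1

# By default the metadata is stripped -- x looks like a plain int.
assert get_type_hints(f) == {"x": int, "return": int}

# Tools that care (e.g. a shape checker) can opt in to see it.
hints = get_type_hints(f, include_extras=True)
assert hints["x"] == Annotated[int, "batch"]
```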

EDIT: it looks like fixes to this may have been merged? unclear: https://github.com/pytorch/pytorch/pull/29623

Linux-cpp-lisp avatar Apr 12 '21 17:04 Linux-cpp-lisp

Haha!

To answer the question: I agree it's unclear whether or not that issue is fixed. Either way, because of that or some other issue, our end use case doesn't seem to be working at the moment.

patrick-kidger avatar Apr 12 '21 17:04 patrick-kidger

Hi! Are there any updates on this, guys?

kharitonov-ivan avatar Nov 09 '21 14:11 kharitonov-ivan

Not that I know of. As far as I'm aware this is still a limitation in TorchScript itself.

If this is a priority for you then you might like to try bringing this up with the torchscript team. They might know more about any possibilities for making this work.

patrick-kidger avatar Nov 09 '21 15:11 patrick-kidger

I have found a workaround. Let's say you have the following function

def f(x: TensorType["batch", "feature"]):
    return x.sum()

which you want to use in TorchScript. TorchScript does not like generic types in signatures, but we want to keep the dimension annotations somewhere for documentation purposes. We can work around this with a subclass.

import torch
from torchtyping import TensorType

class BatchedFeatureTensor(TensorType["batch", "feature"]):
    pass

@torch.jit.script
def f(x: BatchedFeatureTensor):
    return x.sum()

print(f(torch.tensor([[-1.0, 2.0, 1.2]])))
print(f.code)

# => tensor(2.2000)
# => def f(x: Tensor) -> Tensor:
# =>   return torch.sum(x)
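The reason this works (a sketch using a stdlib generic as a stand-in for TensorType, since the mechanics are the same): subclassing a subscripted generic produces an ordinary class, so the annotation the compiler has to parse is a plain name rather than a type constructor.

```python
from typing import List

# Stand-in for subclassing TensorType["batch", "feature"]:
# subclass a subscripted generic.
class IntList(List[int]):
    pass

# The subclass is an ordinary class with a plain name and a concrete base,
# so an annotation like "x: IntList" contains no subscript for a
# source-level parser to choke on.
assert IntList.__name__ == "IntList"
assert issubclass(IntList, list)
```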

martenlienen avatar Apr 22 '22 15:04 martenlienen

Found another way to deal with torchscript. Just paste the code and call patch_torchscript() before exporting.

import re
import typing as tp

import torch

ttp_regexp = re.compile(r"TensorType\[[^\]]*\]")
torchtyping_replacer = "torch.Tensor"


def _replace_torchtyping(source_lines: tp.List[str]) -> tp.List[str]:

    # Join all lines
    cat_lines = "".join(source_lines)

    # Quick exit, if torchtyping is not used
    if ttp_regexp.search(cat_lines) is None:
        return source_lines

    # Replace TensorType
    cat_lines = ttp_regexp.sub(torchtyping_replacer, cat_lines)

    # Split back into lines, keeping line endings. (Using split("\n") and
    # re-appending "\n" to every piece would add a spurious blank line,
    # since the trailing newline already produces a final empty piece.)
    source_lines = cat_lines.splitlines(keepends=True)

    return source_lines


def _torchtyping_destruct_wrapper(func: tp.Callable) -> tp.Callable:
    def _wrap_func(obj: tp.Any, error_msg: tp.Optional[str] = None) -> tp.Tuple[tp.List[str], int, tp.Optional[str]]:
        srclines, file_lineno, filename = func(obj, error_msg)
        srclines = _replace_torchtyping(srclines)
        return srclines, file_lineno, filename

    return _wrap_func


def patch_torchscript() -> None:
    """
    Patch torchscript to work with torchtyping.

    Returns: None.

    """
    # Patch _sources if torch >= 1.10.0, else torch.jit.frontend
    if hasattr(torch, "_sources"):
        src = getattr(torch, "_sources")  # noqa: B009
    else:
        src = getattr(torch.jit, "frontend")  # noqa: B009

    src.get_source_lines_and_file = _torchtyping_destruct_wrapper(src.get_source_lines_and_file)
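To make the rewrite concrete, this is the substitution the regex performs on each annotated signature before TorchScript's frontend parses the source (note that [^\]]* stops at the first ], so an annotation containing nested brackets would be truncated):

```python
import re

ttp_regexp = re.compile(r"TensorType\[[^\]]*\]")

src = 'def f(x: TensorType["batch", "feature"]) -> TensorType["batch"]:'
print(ttp_regexp.sub("torch.Tensor", src))
# => def f(x: torch.Tensor) -> torch.Tensor:
```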

Datasciensyash avatar Sep 12 '22 13:09 Datasciensyash