Optionally check consistency across function calls within a checked function.

Open patrick-kidger opened this issue 3 years ago • 0 comments

Essentially:

typeguard uses a "memo" object as a place to store information over the duration of a function call. At the moment, the checking is primarily performed by:

  • Storing extra information in the memo (specifically the pairs of TensorTypes and the corresponding actual tensors passed as arguments);
  • Then parsing all of these to perform the extra checking.
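
As a rough illustration of that second step, the stored pairs can be reduced to one size per dimension name. This is only a sketch: `check_pairs` and the `Pair` alias are hypothetical names, and plain tuples stand in for the TensorTypes and tensors that the real memo holds.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for the (TensorType, tensor) pairs kept in the memo:
# each entry is (dimension names, actual shape).
Pair = Tuple[Tuple[str, ...], Tuple[int, ...]]


def check_pairs(pairs: List[Pair]) -> Dict[str, int]:
    """Bind each named dimension to a size and fail on any conflict."""
    sizes: Dict[str, int] = {}
    for names, shape in pairs:
        if len(names) != len(shape):
            raise TypeError(f"expected {len(names)} dimensions, got {len(shape)}")
        for name, size in zip(names, shape):
            # setdefault returns the existing binding if there is one.
            bound = sizes.setdefault(name, size)
            if bound != size:
                raise TypeError(f"dimension {name!r} is both {bound} and {size}")
    return sizes
```

Within a single call this already catches an argument annotated ["x", "y"] with shape (2, 3) arriving alongside one annotated ["y", "x"] with the same shape.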

It should be possible to extend this to check consistency across nested function calls, e.g.

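import torch
from torchtyping import TensorType

# Returns shape (2, 3), so "x" = 2 and "y" = 3.
def f() -> TensorType["x", "y"]:
    return torch.rand(2, 3)

# Receives shape (2, 3), so "y" = 2 and "x" = 3.
def g(tensor: TensorType["y", "x"]):
    pass

g(f())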

should raise an error: f binds "x" = 2 and "y" = 3, while g binds "y" = 2 and "x" = 3, so "x" and "y" are each given both 2 and 3 as sizes.

The solution should be to:

  • Create an additional thread-local storage.
  • Have each memo register itself in the storage for as long as it exists (that is, for the duration of its function call).
  • Have each memo compare the inferred sizes of dimensions "x", "y" etc. against those of other memos in the storage.
  • Raise an error if they don't match.
  • Have the memo only perform the checking for function calls in the same file, to avoid incompatibility between uses of torchtyping in library code and user code.
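
The steps above could look roughly like the following. This is a naive sketch under stated assumptions, not torchtyping's or typeguard's actual code: `Memo`, `live_memos`, `registered` and `bind` are hypothetical names, and the "same file" rule is modelled by a plain `filename` string.

```python
import threading
from contextlib import contextmanager
from typing import Dict, Iterator, List

_storage = threading.local()  # holds one list of live memos per thread


class Memo:
    """Hypothetical stand-in for typeguard's memo; not the real class."""

    def __init__(self, filename: str) -> None:
        self.filename = filename
        self.sizes: Dict[str, int] = {}


def live_memos() -> List[Memo]:
    """Return this thread's list of registered memos, creating it if needed."""
    if not hasattr(_storage, "memos"):
        _storage.memos = []
    return _storage.memos


@contextmanager
def registered(memo: Memo) -> Iterator[Memo]:
    """Keep the memo in the storage for the duration of its function call."""
    memos = live_memos()
    memos.append(memo)
    try:
        yield memo
    finally:
        memos.remove(memo)


def bind(memo: Memo, name: str, size: int) -> None:
    """Record a size after comparing it against every live memo in the same file."""
    for other in live_memos():
        if other.filename != memo.filename:
            continue  # library code and user code are checked separately
        bound = other.sizes.get(name)
        if bound is not None and bound != size:
            raise TypeError(f"dimension {name!r} is both {bound} and {size}")
    memo.sizes[name] = size
```

Note that this version only sees memos that are still registered. A call that has already returned has left the storage, so its sizes would have to be retained by an enclosing memo to be compared against later calls.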

Will need to think about how to do this optimally. A naive implementation has every call to _check_memo look in the storage and compare against each memo there. With a call stack of depth n, each of the n memos scans up to n others, so the total cost is O(n^2).

patrick-kidger, Apr 01 '21 17:04