torchtyping
Optionally check consistency across function calls within a checked function.
Essentially: `typeguard` uses a "memo" object as a place to store information over the duration of a function call. At the moment, the checking is primarily performed by:
- Storing extra information in the memo (specifically the pairs of `TensorType`s and the corresponding actual tensors passed as arguments);
- Then parsing all of these to perform the extra checking.
It should be possible to extend this to check consistency across nested function calls, e.g.
```python
import torch
from torchtyping import TensorType

def f() -> TensorType["x", "y"]:
    return torch.rand(2, 3)

def g(tensor: TensorType["y", "x"]):
    pass

g(f())
```
should raise an error, since "x" and "y" each get both 2 and 3 passed as sizes.
The solution should be to:
- Create an additional thread-local storage.
- For the duration of a memo's (= function call's) existence, have it register itself in the storage.
- Have each memo compare the inferred sizes of dimensions "x", "y", etc. against those of the other memos in the storage.
- Raise an error if they don't match.
- Have the memo only perform the checking for function calls in the same file, to avoid incompatibility between uses of `torchtyping` in library code and user code.
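The steps above might be sketched like this. The names (`_storage`, `Memo.check`) are assumptions, and a context manager stands in for the lifetime of a checked function call:

```python
import threading

_storage = threading.local()

class Memo:
    """Hypothetical sketch of a memo that registers itself in
    thread-local storage for the duration of one function call."""

    def __init__(self, filename):
        self.filename = filename  # file the checked function lives in
        self.sizes = {}           # dimension name -> observed size

    def __enter__(self):
        # Register for the duration of the (simulated) function call.
        if not hasattr(_storage, "memos"):
            _storage.memos = []
        _storage.memos.append(self)
        return self

    def __exit__(self, *exc):
        _storage.memos.pop()

    def check(self, name, size):
        # Compare against every other live memo from the same file.
        for other in _storage.memos:
            if other is self or other.filename != self.filename:
                continue
            existing = other.sizes.get(name)
            if existing is not None and existing != size:
                raise TypeError(
                    f"Dimension {name!r} has inconsistent sizes: "
                    f"{existing} and {size}."
                )
        self.sizes[name] = size
```

With this, an inner call that sees `"x" = 3` while an outer call in the same file already recorded `"x" = 2` raises, while a memo from a different file is skipped entirely.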
Will need to think about how to do this optimally. A naive implementation involves every call to `_check_memo` looking in the storage and doing this extra checking, but that will be O(n^2) in the depth n of the call stack.