jaxtyping
IPython `inspect.getsource()` failure due to incorrect `co_firstlineno`
This Colab shows an unexpected side effect of enabling jaxtyping's automatic type checking in IPython: it causes `inspect.getsource` to retrieve the wrong source text for a given function.
That is, if I run these two cells (with `torch` and `jaxtyping` already imported):
```python
def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b

def arange(i: int) -> jaxtyping.Int32[torch.Tensor, "i"]:
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))

def ones(i: int) -> jaxtyping.Int32[torch.Tensor, "{i}"]:
    return arange(i) - arange(i) + 1
```
```python
import inspect
inspect.getsource(ones)
```
Then I get:
```
'def ones(i: int) -> jaxtyping.Int32[torch.Tensor, "{i}"]:\n    return arange(i) - arange(i) + 1\n'
```
But if I turn on type checking by using:
```
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype
```
then running the same `inspect.getsource(ones)` returns the source of `where` instead:
```
'def where(q, a, b):\n    "Use this function to replace an if-statement."\n    return (q * a) + (~q) * b\n'
```
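A minimal, dependency-free sketch of the mechanism named in the title. `inspect.getsource` finds a function's text by reading the file named in `__code__.co_filename` and starting at line `__code__.co_firstlineno`, so if the code object carries the wrong line number, it returns whatever `def` sits at that line. This sketch assumes (unconfirmed against jaxtyping's internals) that the type-checking hook leaves `ones` with a code object whose `co_firstlineno` points at line 1 of the cell; the temporary file and the helper `define_functions` are hypothetical stand-ins for IPython's cell storage, and `torch` is dropped so nothing beyond the standard library is needed.

```python
import inspect
import os
import tempfile

# Two functions in one "cell", mirroring the report (torch-free stand-ins).
SOURCE = '''\
def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b

def ones(i):
    return [1] * i
'''


def define_functions(source):
    """Hypothetical helper: write `source` to a real file and exec it there.

    The file is left on disk so that inspect/linecache can read it back.
    """
    fd, path = tempfile.mkstemp(suffix=".py")
    with os.fdopen(fd, "w") as f:
        f.write(source)
    namespace = {"__name__": "demo"}
    exec(compile(source, path, "exec"), namespace)
    return namespace


ones = define_functions(SOURCE)["ones"]

# Correct line number: getsource starts reading at the `def ones` line.
before = inspect.getsource(ones)

# Simulate a wrapper that rebuilds the code object with the wrong line number:
# line 1 of the file is where `where` is defined.
ones.__code__ = ones.__code__.replace(co_firstlineno=1)
after = inspect.getsource(ones)

print(before)
print(after)  # source text of `where`, not `ones`
```

If this is the cause, comparing `ones.__code__.co_firstlineno` before and after `%load_ext jaxtyping` in the Colab should show the line number changing.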