
IPython `inspect.getsource()` failure due to incorrect `co_firstlineno`

Open davideger opened this issue 1 year ago • 4 comments

This colab shows an unexpected side effect of enabling automatic jaxtyping type checking in IPython: it causes `inspect.getsource` to retrieve the wrong source text for a given function.

That is, if I run these two cells (with `jaxtyping` and `torch` already imported):

```python
def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b

def arange(i: int) -> jaxtyping.Int32[torch.Tensor, "i"]:
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))


def ones(i: int) -> jaxtyping.Int32[torch.Tensor, "{i}"]:
    return arange(i) - arange(i) + 1
```

```python
import inspect

inspect.getsource(ones)
```

Then I get:

```
'def ones(i: int) -> jaxtyping.Int32[torch.Tensor, "{i}"]:\n    return arange(i) - arange(i) + 1\n'
```

But if I turn on type checking by using:

```python
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype
```

Then running the same `inspect.getsource(ones)` call yields:

```
'def where(q, a, b):\n    "Use this function to replace an if-statement."\n    return (q * a) + (~q) * b\n'
```
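The title points at `co_firstlineno`: for a function, `inspect.getsource` takes the code object's `co_firstlineno` as the starting point in the cached cell text and reads the block found there. So if a cell transformation leaves a function with the wrong start line, `getsource` returns whichever block starts at that line. The sketch below reproduces the symptom with only the standard library. The cell name, the simplified `ones` body, and `ClobberFunctionLineno` (which just forces each function's `lineno` to 1) are hypothetical stand-ins, not jaxtyping's actual transformer; the only assumption is that the extension's rewrite leaves `ones` with a start line of 1, which is consistent with the output above.

```python
import ast
import inspect
import linecache

# Hypothetical stand-in for one notebook cell. IPython registers cell text in
# linecache in a similar way, which is what lets inspect find it.
CELL_NAME = "<hypothetical-cell-1>"
CELL_SOURCE = (
    "def where(q, a, b):\n"
    '    "Use this function to replace an if-statement."\n'
    "    return (q * a) + (~q) * b\n"
    "\n"
    "def ones(i):\n"
    "    return [1] * i\n"
)


class ClobberFunctionLineno(ast.NodeTransformer):
    """Hypothetical transformer that loses each function's start line.

    It forces lineno to 1. This is NOT jaxtyping's code; it only mimics
    a rewrite that leaves the compiled function with a wrong co_firstlineno.
    """

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        node.lineno = 1
        return node


def run_cell(transform):
    # Cache entry layout: (size, mtime, lines, fullname); mtime None = not a real file.
    linecache.cache[CELL_NAME] = (
        len(CELL_SOURCE), None, CELL_SOURCE.splitlines(True), CELL_NAME)
    tree = ast.parse(CELL_SOURCE, CELL_NAME)
    if transform is not None:
        tree = transform.visit(tree)
    namespace = {}
    exec(compile(tree, CELL_NAME, "exec"), namespace)
    return namespace["ones"]


good = run_cell(None)
bad = run_cell(ClobberFunctionLineno())

print(good.__code__.co_firstlineno)  # 5
print(bad.__code__.co_firstlineno)   # 1
print(repr(inspect.getsource(good)))
print(repr(inspect.getsource(bad)))  # the where() block, as in the report
```

If this is the mechanism, a transformer that builds or modifies function nodes would need to carry the original positions over (for example with `ast.copy_location`) before the cell is compiled.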

davideger, Jan 10 '24 05:01