jaxtyping

Functions without type hints and import hook

nimashoghi opened this issue 1 year ago · 1 comment

Hi!

First of all, thanks for the awesome library. This library has made my code much more understandable, and the runtime type-checking with beartype has been immensely useful.

I'm currently working on an existing (PyTorch) codebase that did not previously use jaxtyping type hints, and I'm gradually adding type hints to the areas I work on. As a result, I have a handful of cases where I'm not using function argument/return type hints, but am instead using `isinstance` checks, e.g.:


```python
import torch
from jaxtyping import Float

def my_func(x):
    x = ...  # some operation I'm not touching
    # my code
    assert isinstance(x, Float[torch.Tensor, "bsz channels"])
    x = do_some_other_stuff(x)
    # ... and then the rest of the code for `my_func`
```

In these cases, because of the way the import hook is currently set up, I'm running into some very strange and unexpected behavior. Specifically, axis bindings made inside these functions (e.g. `bsz` and `channels` above) seem to be silently ignored: they never get registered in the `memo_stack`.

This seems to be because `my_func` above has no jaxtyping type hints on its arguments or return type, so the import hook never registers it.
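To see why the bindings vanish, here is a toy, stdlib-only model of the mechanism described above. It is a sketch under stated assumptions, not jaxtyping's implementation: `memo_stack`, `typechecked`, and `check_dims` are hypothetical stand-ins. A decorated function pushes a fresh memo for each call, and each shape check binds axis names in the top memo, so inconsistent sizes across checks are caught only when an enclosing function has been wrapped.

```python
import functools

# Hypothetical stand-in for jaxtyping's memo stack: one dict of axis bindings per call.
memo_stack: list[dict[str, int]] = []


def typechecked(fn):
    """Hypothetical decorator: push a fresh memo for the duration of each call."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        memo_stack.append({})
        try:
            return fn(*args, **kwargs)
        finally:
            memo_stack.pop()
    return wrapper


def check_dims(shape: tuple[int, ...], names: str) -> bool:
    """Hypothetical isinstance-style check: bind each named axis in the top memo."""
    dims = names.split()
    if len(shape) != len(dims):
        return False
    # With no enclosing decorated call, use a throwaway memo: bindings are discarded.
    memo = memo_stack[-1] if memo_stack else {}
    for dim, size in zip(dims, shape):
        if memo.setdefault(dim, size) != size:
            return False
    return True


def undecorated(a, b):
    # Each check passes on its own; the inconsistent "bsz" goes unnoticed.
    return check_dims(a, "bsz channels") and check_dims(b, "bsz")


@typechecked
def decorated(a, b):
    # "bsz" is bound to 4 by the first check, so the second check sees 3 != 4.
    return check_dims(a, "bsz channels") and check_dims(b, "bsz")


print(undecorated((4, 8), (3,)))  # True: bindings were never recorded
print(decorated((4, 8), (3,)))    # False: bsz mismatch detected
```

In this model, `undecorated` plays the role of `my_func` when the import hook skips it: each check passes in isolation and the mismatch between the two `bsz` sizes goes unnoticed.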

For now, I've patched jaxtyping's import hook code (`_import_hook.py`) to also register every function that contains an `isinstance` call:

```python
import ast

def _has_isinstance(func_def):
    for node in ast.walk(func_def):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "isinstance":
            return True
    return False
```

I then check for this in `JaxtypingTransformer`, changing the relevant lines of `visit_FunctionDef` to:

```python
    def visit_FunctionDef(self, node: ast.FunctionDef):
        has_annotated_args = any(arg for arg in node.args.args if arg.annotation)
        has_annotated_return = bool(node.returns)
        has_isinstance = _has_isinstance(node)
        if has_annotated_args or has_annotated_return or has_isinstance:
            ...  # rest of the method unchanged
```

This is a hacky fix, but it works in my case. I'd love to hear your thoughts on how to fix this properly (and whether a similar fix is warranted in the meantime).
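The effect of the patched condition can be checked on a few sample functions without running the hook itself. This is a hedged sketch: `original_condition` mirrors the two annotation checks from the patched snippet without the `isinstance` test (assuming that was the condition before the patch), `patched_condition` adds `_has_isinstance`, and the sample sources are invented:

```python
import ast
import textwrap


def _has_isinstance(func_def: ast.AST) -> bool:
    # Same syntactic check as in the issue: any call to the bare name `isinstance`.
    return any(
        isinstance(n, ast.Call) and isinstance(n.func, ast.Name) and n.func.id == "isinstance"
        for n in ast.walk(func_def)
    )


def original_condition(node: ast.FunctionDef) -> bool:
    # Annotation-only condition: annotated positional args or a return annotation.
    has_annotated_args = any(arg.annotation for arg in node.args.args)
    has_annotated_return = node.returns is not None
    return has_annotated_args or has_annotated_return


def patched_condition(node: ast.FunctionDef) -> bool:
    # The patch from the issue: also register functions containing isinstance calls.
    return original_condition(node) or _has_isinstance(node)


source = textwrap.dedent(
    '''
    def annotated(x: Float[torch.Tensor, "bsz channels"]): ...

    def my_func(x):
        assert isinstance(x, Float[torch.Tensor, "bsz channels"])

    def plain(x):
        return x
    '''
)
funcs = [n for n in ast.parse(source).body if isinstance(n, ast.FunctionDef)]
for fn in funcs:
    print(fn.name, original_condition(fn), patched_condition(fn))
# annotated True True
# my_func False True
# plain False False
```

As written, the annotation check looks only at `node.args.args`, so annotations on keyword-only parameters or `*args` are not considered by this toy predicate.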

Thanks!

nimashoghi · Apr 02 '24 00:04

Ah, good catch!

Maybe we should just remove the `has_annotated_args or has_annotated_return` check? I don't think it's important; it was only a minor efficiency optimization.
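Dropping the check, as suggested, amounts to decorating every function unconditionally. A minimal toy sketch with `ast.NodeTransformer` follows; the decorator name `typecheck_stub` is hypothetical (the real import hook inserts its own decorator), so this illustrates the idea rather than jaxtyping's code:

```python
import ast
import textwrap

# Hypothetical decorator name; the real hook inserts its own decorator expression.
DECORATOR = "typecheck_stub"


class DecorateEveryFunction(ast.NodeTransformer):
    """Toy transformer: decorate every function, with no annotation check at all."""

    def visit_FunctionDef(self, node):
        self.generic_visit(node)  # also decorate nested functions
        node.decorator_list.insert(0, ast.Name(id=DECORATOR, ctx=ast.Load()))
        return node

    # async def functions get the same treatment.
    visit_AsyncFunctionDef = visit_FunctionDef


source = textwrap.dedent(
    '''
    def my_func(x):
        assert isinstance(x, int)
        return x
    '''
)
tree = ast.fix_missing_locations(DecorateEveryFunction().visit(ast.parse(source)))
print(ast.unparse(tree))
# @typecheck_stub
# def my_func(x):
#     assert isinstance(x, int)
#     return x
```

The trade-off is the one mentioned above: functions that never check anything now pay a small per-call wrapping cost, in exchange for never silently skipping an `isinstance` check.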

I'd be happy to take a PR on this.

patrick-kidger · Apr 03 '24 18:04

Closing, as this was fixed in #205, which is included in jaxtyping version 0.2.29.

patrick-kidger · Aug 18 '24 19:08