
[Feature Request] @beartype + `jaxtyping` + forward references = 💣 🔥

Status: Open • sean-roelofs-ai opened this issue 4 months ago • 2 comments

The @beartype decorator fails at load time when decorating a function that uses self-referential forward references.

Since beartype is ultimately the package that raises the error, I opened the issue here instead of in the jaxtyping repository.

Interestingly, @jaxtyped(typechecker=beartype) does work, though it fails to meet my use case for other reasons.

Minimally reproducible example:

from __future__ import annotations
from jaxtyping import Shaped, jaxtyped
from beartype import beartype


class MyClass:
    
    # @jaxtyped(typechecker=beartype) # <- this works
    @beartype # <- this fails
    def foo(self: Shaped[MyClass, "*#batches"]) -> Shaped[MyClass, "*#batches"]:
        pass

Error message:

❯ python test.py
Traceback (most recent call last):
  File "test.py", line 233, in <module>
    class MyClass:
  File "test.py", line 237, in MyClass
    def foo(self: Shaped[MyClass, "*#batches"]) -> Shaped[MyClass, "*#batches"]:
  File ".venv/lib/python3.10/site-packages/beartype/_decor/decorcache.py", line 74, in beartype
    return beartype_object(obj, conf)
  File ".venv/lib/python3.10/site-packages/beartype/_decor/decorcore.py", line 87, in beartype_object
    _beartype_object_fatal(obj, conf=conf, **kwargs)
  File ".venv/lib/python3.10/site-packages/beartype/_decor/decorcore.py", line 137, in _beartype_object_fatal
    beartype_nontype(obj, **kwargs)  # type: ignore[return-value]
  File ".venv/lib/python3.10/site-packages/beartype/_decor/_nontype/decornontype.py", line 301, in beartype_nontype
    return beartype_func(obj, **kwargs)  # type: ignore[return-value]
  File ".venv/lib/python3.10/site-packages/beartype/_decor/_nontype/decornontype.py", line 389, in beartype_func
    func_wrapper_code = generate_code(decor_meta)
  File ".venv/lib/python3.10/site-packages/beartype/_decor/wrap/wrapmain.py", line 122, in generate_code
    code_check_params = _code_check_args(decor_meta)
  File ".venv/lib/python3.10/site-packages/beartype/_decor/wrap/_wrapargs.py", line 430, in code_check_args
    reraise_exception_placeholder(
  File ".venv/lib/python3.10/site-packages/beartype/_util/error/utilerrraise.py", line 137, in reraise_exception_placeholder
    raise exception.with_traceback(exception.__traceback__)
  File ".venv/lib/python3.10/site-packages/beartype/_decor/wrap/_wrapargs.py", line 269, in code_check_args
    hint_or_sane = sanify_hint_root_func(
  File ".venv/lib/python3.10/site-packages/beartype/_check/convert/convsanify.py", line 200, in sanify_hint_root_func
    hint_or_sane = reduce_hint(
  File ".venv/lib/python3.10/site-packages/beartype/_check/convert/_reduce/redhint.py", line 376, in reduce_hint
    hint = _reduce_hint_uncached(
  File ".venv/lib/python3.10/site-packages/beartype/_check/convert/_reduce/redhint.py", line 600, in _reduce_hint_uncached
    die_unless_hint(hint=hint, exception_prefix=exception_prefix)
  File ".venv/lib/python3.10/site-packages/beartype/_util/hint/utilhinttest.py", line 106, in die_unless_hint
    die_unless_hint_nonpep(hint=hint, exception_prefix=exception_prefix)
  File ".venv/lib/python3.10/site-packages/beartype/_util/hint/nonpep/utilnonpeptest.py", line 181, in die_unless_hint_nonpep
    die_unless_hint_nonpep_type(
  File ".venv/lib/python3.10/site-packages/beartype/_util/hint/nonpep/utilnonpeptest.py", line 275, in die_unless_hint_nonpep_type
    die_unless_type_isinstanceable(
  File ".venv/lib/python3.10/site-packages/beartype/_util/cls/pep/clspep3119.py", line 277, in die_unless_type_isinstanceable
    isinstance(None, cls)  # type: ignore[arg-type]
  File ".venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 180, in __instancecheck__
    return cls.__instancecheck_str__(obj) == ""
  File ".venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 189, in __instancecheck_str__
    if not isinstance(obj, cls.array_type):
  File ".venv/lib/python3.10/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 295, in __instancecheck__
    return cls.__is_instance_beartype__(obj)
  File ".venv/lib/python3.10/site-packages/beartype/_check/forward/reference/fwdrefabc.py", line 112, in __is_instance_beartype__
    return isinstance(obj, cls.__type_beartype__)  # type: ignore[arg-type]
  File ".venv/lib/python3.10/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 451, in __type_beartype__
    referent = import_module_attr(
  File ".venv/lib/python3.10/site-packages/beartype/_util/module/utilmodimport.py", line 295, in import_module_attr
    raise exception_cls(exception_message)
beartype.roar.BeartypeCallHintForwardRefException: Forward reference "MyClass" unimportable from module "__main__".
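Reading the bottom of that traceback upward: beartype probes the hint with isinstance(None, cls) at decoration time (die_unless_type_isinstanceable), jaxtyping's __instancecheck__ forwards that to isinstance(obj, cls.array_type), and array_type turns out to be beartype's forward-reference proxy, which then tries to look up MyClass while the class body is still executing. The stdlib-only sketch below reproduces that timing problem in miniature. The _ForwardRefMeta proxy and make_forward_ref() helper are hypothetical stand-ins, not beartype's real classes, and they resolve names from a plain namespace dict rather than by importing a module.

```python
class _ForwardRefMeta(type):
    """Hypothetical lazy forward-reference proxy (not beartype's real class).

    The referenced name is only looked up when an isinstance() check runs.
    """

    def __instancecheck__(cls, obj) -> bool:
        try:
            referent = cls.namespace[cls.ref_name]
        except KeyError:
            raise NameError(
                f"Forward reference {cls.ref_name!r} not yet defined."
            ) from None
        return isinstance(obj, referent)


def make_forward_ref(ref_name: str, namespace: dict) -> type:
    """Hypothetical helper: build a proxy class for a not-yet-defined name."""
    return _ForwardRefMeta(
        f"ForwardRef_{ref_name}", (), {"ref_name": ref_name, "namespace": namespace}
    )


MY_CLASS_REF = make_forward_ref("MyClass", globals())


class MyClass:
    # Mimic a decorator that probes its hint eagerly, inside the class body,
    # before the name "MyClass" has been bound in the module namespace.
    try:
        isinstance(None, MY_CLASS_REF)
        early_error = None
    except NameError as exc:
        early_error = str(exc)


# Once the class statement has finished, the very same proxy resolves fine.
print(MyClass.early_error)
print(isinstance(MyClass(), MY_CLASS_REF))
```

The proxy itself is not broken; it is simply consulted one statement too early.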

sean-roelofs-ai • Aug 25 '25 20:08

@patrick-kidger: Fellow QA wizard, I summon thee! Firstly, apologies for not properly finishing up feature request #544 ("Official jaxtyping integration") with you. I have excuses. They are bad excuses – yet, they are mine. As the Python 3.14 release date looms, I've been frantically patching up the @beartype codebase to support all PEP 649 and 749 edge cases. Interestingly, that tangentially relates to this issue, because...

@sean-roelofs-ai: Thanks so much for bringing our mutual attention to this fascinating horror show. The central issue is this:

from __future__ import annotations

That's PEP 563. And... that's now on its way out. PEP 749 officially scheduled PEP 563 for deprecation a few months ago:

Sometime after the last release that did not support PEP 649 semantics (expected to be 3.13) reaches its end-of-life, from __future__ import annotations is deprecated. Compiling any code that uses the future import will emit a DeprecationWarning. This will happen no sooner than the first release after Python 3.13 reaches its end-of-life, but the community may decide to wait longer. After at least two releases, the future import is removed, and annotations are always evaluated as per PEP 649. Code that continues to use the future import will raise a SyntaxError, similar to any other undefined future import.

tl;dr: No sooner than October ~15th 2029 (when Python 3.13 reaches end-of-life), from __future__ import annotations will be officially deprecated. No sooner than October ~15th 2031, from __future__ import annotations will be removed entirely from the Python language. At that time, any module using from __future__ import annotations will raise a SyntaxError at importation time and thus become unimportable. In 2025, nobody should enable from __future__ import annotations voluntarily.

In other words, your code is a ticking time bomb. Take it from a balding, malding middle-aged dude: six years will be here faster than anyone would like to believe. It happened to me. It can happen to your code, too. 😬
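For concreteness, here is roughly what from __future__ import annotations does to a self-referential hint. Rather than putting the future import in this snippet (which would stringify the snippet's own annotations too), the sketch compiles a tiny module body with the equivalent compiler flag, __future__.annotations.compiler_flag. exec_with_pep563() is a hypothetical helper name.

```python
import __future__
import typing

# A tiny module body WITHOUT quotes around the self-referential hint.
SRC = """
class MyClass:
    def foo(self, other: MyClass, count: int) -> MyClass:
        return other
"""


def exec_with_pep563(src: str) -> dict:
    """Hypothetical helper: run *src* as if it began with the future import."""
    code = compile(
        src, "<pep563-demo>", "exec",
        flags=__future__.annotations.compiler_flag, dont_inherit=True,
    )
    namespace: dict = {}
    exec(code, namespace)
    return namespace


ns = exec_with_pep563(SRC)
MyClass = ns["MyClass"]

# PEP 563 stores EVERY annotation as a string, even ones like int that
# could have been evaluated immediately...
raw = MyClass.foo.__annotations__
# ...and leaves turning them back into objects to whoever asks later.
resolved = typing.get_type_hints(MyClass.foo)
print(raw)
print(resolved)
```

Anything that inspects the strings too early, or against the wrong namespace, cannot turn 'MyClass' back into a class.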

Understandably, neither @beartype nor jaxtyping is likely to devote scarce resources to a deprecated standard that is about to destroy everyone's code. Instead...

PEP 649 + 749: They Exist, But Only Under Python ≥ 3.14

You now have two equally grimdark choices:

  • Just use stringified forward references instead.
  • Continue to use unquoted forward references without from __future__ import annotations by requiring Python ≥ 3.14.
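In plain Python (leaving jaxtyping out of it for the moment), the first choice looks like the sketch below: the quoted hint is stored as a string and only resolved later, by which point MyClass exists. typing.get_type_hints() stands in here for whatever a runtime type-checker does internally; it is not what @beartype itself calls.

```python
import typing


class MyClass:
    # Choice 1: quote the self-referential hint, so nothing needs to be
    # evaluated while the class body is still executing.
    def foo(self, other: "MyClass") -> "MyClass":
        return other


# The raw annotations are just strings...
raw = MyClass.foo.__annotations__

# ...resolved afterwards against the defining module's globals, where the
# name MyClass is now bound.
hints = typing.get_type_hints(MyClass.foo)
print(raw)
print(hints)
```

The second choice drops the quotes and relies on PEP 649 (Python ≥ 3.14) to defer evaluation in much the same way.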

Does @beartype + jaxtyping actually support either of these two choices? Technically, @beartype already supports both. Pragmatically, that doesn't mean that the combo of @beartype + jaxtyping does. When two things combine, they often explode.

Let's roll up our sleeves and get dirty with some minimal-length examples. First up, stringified forward references:

from jaxtyping import Shaped, jaxtyped
from beartype import beartype


class MyClass:
    
    @jaxtyped(typechecker=beartype) # <- this works
    # @beartype # <- this fails
    def foo(self: Shaped['MyClass', "*#batches"]) -> Shaped['MyClass', "*#batches"]:
        pass

...which raises (under Python 3.14.0b3):

Traceback (most recent call last):
  File "/home/leycec/tmp/mopy.py", line 7, in <module>
    class MyClass:
    ...<5 lines>...
            pass
  File "/home/leycec/tmp/mopy.py", line 9, in MyClass
    @jaxtyped(typechecker=beartype) # <- this works
     ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/site-packages/jaxtyping/_decorator.py", line 393, in jaxtyped
    full_signature = inspect.signature(fn)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/inspect.py", line 3312, in signature
    return Signature.from_callable(obj, follow_wrapped=follow_wrapped,
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   globals=globals, locals=locals, eval_str=eval_str,
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   annotation_format=annotation_format)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/inspect.py", line 3027, in from_callable
    return _signature_from_callable(obj, sigcls=cls,
                                    follow_wrapper_chains=follow_wrapped,
                                    globals=globals, locals=locals, eval_str=eval_str,
                                    annotation_format=annotation_format)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/inspect.py", line 2502, in _signature_from_callable
    return _signature_from_function(sigcls, obj,
                                    skip_bound_arg=skip_bound_arg,
                                    globals=globals, locals=locals, eval_str=eval_str,
                                    annotation_format=annotation_format)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/inspect.py", line 2325, in _signature_from_function
    annotations = get_annotations(func, globals=globals, locals=locals, eval_str=eval_str,
                                  format=annotation_format)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/annotationlib.py", line 895, in get_annotations
    ann = _get_dunder_annotations(obj)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/annotationlib.py", line 1063, in _get_dunder_annotations
    ann = getattr(obj, "__annotations__", None)
  File "/home/leycec/tmp/mopy.py", line 12, in __annotate__
    def foo(self: Shaped['MyClass', "*#batches"]) -> Shaped['MyClass', "*#batches"]:
                  ~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/site-packages/jaxtyping/_array_types.py", line 663, in __getitem__
    out = _make_array(array_type, dim_str, cls)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/site-packages/jaxtyping/_array_types.py", line 600, in _make_array
    out = _make_array_cached(x, dim_str, dtype.dtypes, dtype.__name__)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/site-packages/jaxtyping/_array_types.py", line 563, in _make_array_cached
    if array_type is not Any and issubclass(array_type, AbstractArray):
                                 ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: issubclass() arg 1 must be a class
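The last frame of that traceback is easy to reproduce without jaxtyping: issubclass() accepts only actual classes, so the string 'MyClass' is rejected outright. A guard like the hypothetical is_array_subclass() below would avoid the crash, but only by declining to resolve the string; real support for stringified references would still need some deferred resolution step.

```python
def is_array_subclass(array_type: object, base: type) -> bool:
    """Hypothetical guard: call issubclass() only on genuine classes."""
    return isinstance(array_type, type) and issubclass(array_type, base)


# The unguarded call fails the same way as the traceback's final frame.
try:
    issubclass("MyClass", object)
except TypeError as exc:
    raw_error = str(exc)
else:
    raw_error = None

print(raw_error)
print(is_array_subclass("MyClass", object))
print(is_array_subclass(int, object))
```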

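That's... not good. jaxtyping doesn't like stringified forward references, huh? Its Shaped[...] subscription passes the string 'MyClass' straight to issubclass(), which only accepts actual classes. Next up, let's try unquoted forward references via PEP 649 instead under Python ≥ 3.14: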

from jaxtyping import Shaped, jaxtyped
from beartype import beartype


class MyClass:
    
    @jaxtyped(typechecker=beartype) # <- this works
    # @beartype # <- this fails
    def foo(self: Shaped[MyClass, "*#batches"]) -> Shaped[MyClass, "*#batches"]:
        pass

...which raises:

Traceback (most recent call last):
  File "/home/leycec/tmp/mopy.py", line 7, in <module>
    class MyClass:
    ...<5 lines>...
            pass
  File "/home/leycec/tmp/mopy.py", line 9, in MyClass
    @jaxtyped(typechecker=beartype) # <- this works
     ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/site-packages/jaxtyping/_decorator.py", line 393, in jaxtyped
    full_signature = inspect.signature(fn)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/inspect.py", line 3312, in signature
    return Signature.from_callable(obj, follow_wrapped=follow_wrapped,
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   globals=globals, locals=locals, eval_str=eval_str,
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   annotation_format=annotation_format)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/inspect.py", line 3027, in from_callable
    return _signature_from_callable(obj, sigcls=cls,
                                    follow_wrapper_chains=follow_wrapped,
                                    globals=globals, locals=locals, eval_str=eval_str,
                                    annotation_format=annotation_format)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/inspect.py", line 2502, in _signature_from_callable
    return _signature_from_function(sigcls, obj,
                                    skip_bound_arg=skip_bound_arg,
                                    globals=globals, locals=locals, eval_str=eval_str,
                                    annotation_format=annotation_format)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/inspect.py", line 2325, in _signature_from_function
    annotations = get_annotations(func, globals=globals, locals=locals, eval_str=eval_str,
                                  format=annotation_format)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/annotationlib.py", line 895, in get_annotations
    ann = _get_dunder_annotations(obj)
  File "/home/leycec/py/pyenv/versions/3.14.0b3/lib/python3.14/annotationlib.py", line 1063, in _get_dunder_annotations
    ann = getattr(obj, "__annotations__", None)
  File "/home/leycec/tmp/mopy.py", line 11, in __annotate__
    def foo(self: Shaped[MyClass, "*#batches"]) -> Shaped[MyClass, "*#batches"]:
                         ^^^^^^^
NameError: name 'MyClass' is not defined
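The failing frame is jaxtyped's inspect.signature(fn) call at decoration time: under PEP 649, touching __annotations__ evaluates the annotations immediately, while MyClass is still being defined. The sketch below contrasts a hypothetical eager decorator (introspecting at decoration time, as in the traceback) with a hypothetical deferred one (resolving hints on the first call). It illustrates the timing only and is not how jaxtyping or @beartype is implemented. On Python < 3.14 both fail, because annotations are evaluated as soon as the def statement runs.

```python
import inspect
import sys
import typing

# Unquoted self-reference inside a class body. Kept as a string so this file
# still imports under Python < 3.14, where annotations are evaluated eagerly.
SRC = """
class MyClass:
    @DECORATOR
    def foo(self, other: MyClass) -> MyClass:
        return other
"""


def eager(fn):
    """Hypothetical decorator: introspect the signature at decoration time."""
    inspect.signature(fn)  # evaluates fn's annotations right now
    return fn


def deferred(fn):
    """Hypothetical decorator: resolve type hints on the first call only."""
    hints = None

    def wrapper(*args, **kwargs):
        nonlocal hints
        if hints is None:
            hints = typing.get_type_hints(fn)  # MyClass is bound by now
        return fn(*args, **kwargs)

    return wrapper


def define_with(decorator):
    """Execute SRC with *decorator*; return its namespace or the NameError."""
    namespace = {"DECORATOR": decorator}
    try:
        exec(SRC, namespace)
    except NameError as exc:
        return exc
    return namespace


eager_result = define_with(eager)        # NameError on every Python version
deferred_result = define_with(deferred)  # succeeds only on Python >= 3.14
print(sys.version_info[:2], type(eager_result).__name__, type(deferred_result).__name__)
```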

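That's... also super-not good! jaxtyping doesn't like PEP 649, either: jaxtyped calls inspect.signature() at decoration time, which forces PEP 649's lazily deferred annotations to be evaluated while MyClass is still being defined. But PEP 649 is the future of Python. Thus...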

jaxtyping and @beartype Now Have a Problem

Curiously, @patrick-kidger and I were just spontaneously chatting over at #544 about how to more closely integrate @beartype and jaxtyping. We came to no good answers. jaxtyping has special needs and @beartype has no idea how to meet those special needs within the known lifetime of the Universe.

That said, if we solve #544, this issue should spontaneously just "go away" for jaxtyping. Why? Because if jaxtyping just lets @beartype do all the core heavy lifting for PEP-compliant type hint support, then jaxtyping no longer needs to worry about mucky, icky stuff like stringified type hints, PEP 649, or PEP 749. @beartype will transparently handle all of that on behalf of jaxtyping.

2025: it do be like that. 😯

leycec • Aug 26 '25 03:08

Ah, I think I can explain this one. This is way out in the weeds of how jaxtyping works.

Namely, we actually dynamically synthesise a function with the same type signature as the underlying function, and then decorate that function with the typechecker (i.e. beartype or otherwise), and then call that new function with the arguments provided at runtime.

In particular, what this means is that the ForwardRefs are not resolvable, because this dynamically-synthesised function does not have the correct __globals__ dictionary in which to look them up.
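The underlying Python rule is easy to see in isolation: a function resolves its string annotations against its own __globals__, so a function synthesised in some other namespace cannot see names from the user's module. The sketch below uses typing.get_type_hints() as a stand-in for a type-checker's forward-reference resolution (beartype's own machinery differs in detail), and synthesise() is a hypothetical helper.

```python
import typing


class MyClass:
    pass


# Source for a stub mirroring the user's signature (quoted hints, so the
# def itself succeeds on every Python version).
STUB_SRC = '''
def stub(x: "MyClass") -> "MyClass":
    return x
'''


def synthesise(namespace: dict):
    """Hypothetical helper: exec STUB_SRC into *namespace*, return the stub."""
    exec(STUB_SRC, namespace)
    return namespace["stub"]


# Synthesised in a fresh namespace: its __globals__ has no "MyClass".
detached = synthesise({})
try:
    typing.get_type_hints(detached)
    detached_error = None
except NameError as exc:
    detached_error = exc

# Synthesised in a copy of this module's globals: "MyClass" resolves.
attached = synthesise(dict(globals()))
attached_hints = typing.get_type_hints(attached)
print(detached_error)
print(attached_hints)
```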

That whole thing probably sounds a bit crazy, and that's because it is. This dynamically-synthesised-function business is there to support a few things:

  • The ability to separate typechecker errors from errors in the body of the user code.
  • On typechecking errors only, to provide additional jaxtyping contextual information about the values bound to each axis.
  • On typechecking errors only, to determine specifically which argument caused the error (we call a newly synthesised function for each argument in turn and see which fails), so as to offer a more useful error message than typeguard (at least) provides. (Note that, at least for now, typeguard is still very popular for use with jaxtyping, due to its O(n) checking.)
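The third point (blaming one specific argument) can be sketched without jaxtyping. Per the description above, jaxtyping synthesises a fresh function per argument and lets the real typechecker judge each one; the hypothetical first_failing_argument() below swaps that for a direct isinstance() check and only understands plain-class hints, so it shows the idea of per-argument localisation rather than jaxtyping's implementation.

```python
import inspect
import typing


def first_failing_argument(fn, *args, **kwargs):
    """Hypothetical helper: name of the first argument failing its hint, or None.

    Only plain-class hints (int, float, ...) are checked; anything else is skipped.
    """
    hints = typing.get_type_hints(fn)
    bound = inspect.signature(fn).bind(*args, **kwargs)
    # Check each argument on its own, so a failure points at a single name.
    for name, value in bound.arguments.items():
        hint = hints.get(name)
        if isinstance(hint, type) and not isinstance(value, hint):
            return name
    return None


def scale(x: float, factor: int) -> float:
    return x * factor


print(first_failing_argument(scale, 1.5, 2))
print(first_failing_argument(scale, 1.5, "2"))
```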

In terms of fixing this, the best would probably be to find a way to have the jaxtyping-synthesised functions use the correct __globals__ dictionary. This isn't assignable at runtime, though, and I'm pretty hesitant to use the unstable types.FunctionType(...) constructor!
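For reference, here is roughly what both options look like in plain Python. Assigning to __globals__ raises AttributeError because the attribute is read-only, while the types.FunctionType(code, globals, name, argdefs, closure) constructor can build a copy bound to different globals. rebind_globals() is a hypothetical helper; this sketches the approach Patrick is wary of and is not a tested fix for jaxtyping.

```python
import types
import typing


class MyClass:
    pass


# A stub synthesised (hypothetically) in the wrong namespace.
wrong_ns: dict = {}
exec('def stub(x: "MyClass") -> "MyClass":\n    return x\n', wrong_ns)
stub = wrong_ns["stub"]

# Option 1: reassign __globals__. This fails: the attribute is read-only.
try:
    stub.__globals__ = globals()
    reassign_error = None
except AttributeError as exc:
    reassign_error = exc


def rebind_globals(fn: types.FunctionType, new_globals: dict) -> types.FunctionType:
    """Hypothetical helper: clone *fn* so that it uses *new_globals*."""
    clone = types.FunctionType(
        fn.__code__, new_globals, fn.__name__, fn.__defaults__, fn.__closure__
    )
    clone.__kwdefaults__ = fn.__kwdefaults__
    clone.__qualname__ = fn.__qualname__
    # Annotations live on the function object, not the code object, so they
    # must be copied across explicitly.
    clone.__annotations__ = dict(fn.__annotations__)
    return clone


# Option 2: rebuild the function via the types.FunctionType constructor.
rebound = rebind_globals(stub, globals())
print(reassign_error)
print(typing.get_type_hints(rebound))
```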

(Well, the real fix would be to streamline the entire pile-of-hacks that is runtime type checking, but I'll have to wait for another life to get the time for that.)

patrick-kidger • Aug 31 '25 12:08