
Is it possible to support tuple unpacking?

Open antony-frolov opened this issue 4 months ago • 6 comments

Hi! I'm trying to type something like this:

@jaxtyped(typechecker=beartype)
def forward(
    self, x: Float[torch.Tensor, "t dim"], *args: torch.Tensor
) -> tuple[torch.Tensor, "t dim"], *tuple[torch.Tensor, ...]]:
    return func(x), *args

As I've read here https://github.com/patrick-kidger/jaxtyping/blob/main/docs/faq.md, PEP 646 is not supported, so it seems like this kind of return value typing is not supported as of now.

I've just found that the latest version of beartype supports this kind of unpacking, so maybe there's a chance of this being supported in jaxtyping?

Typing the return as tuple[torch.Tensor, *tuple[torch.Tensor, ...]] also doesn't work with @jaxtyped(typechecker=beartype), but it does work with just @beartype.

antony-frolov avatar Aug 05 '25 16:08 antony-frolov

...heh. PEP 646-compliant tuple and type variable tuple unpacking, huh? The QA nightmare fuel that has consumed my every waking moment for the past three months, huh? The day has finally come when users care about PEP 646 – and it isn't a pretty day.

Firstly, though, this particular example doesn't seem quite right. This:

...
) -> tuple[torch.Tensor, "t dim"], *tuple[torch.Tensor, ...]]:

...isn't a valid PEP 646-compliant type hint. I think, anyway. Because PEP 646 is nightmare fuel, it's hard to say. But I'm pretty sure you can't directly comma-delimit type hints. They have to subscript a parent type hint. You probably instead meant:

...
) -> tuple[Float[torch.Tensor, "t dim"], *tuple[torch.Tensor, ...]]:

Something like that, maybe?

PEP 646: This Specification Was Built on Pain

@patrick-kidger: You're amazing. We know this. Because you're amazing, your time is valuable. Personally, I wouldn't advise trying to explicitly support or handle PEP 646 in jaxtyping. PEP 646 is a cesspit of hundreds of devious edge cases that intersect in unpleasant and unsavoury ways. Unsurprisingly, PEP 646 spans nearly 40 pages of dead tree paper. It's also kinda "all or nothing." You either gotta implement all of PEP 646 or none of PEP 646. There's not much leeway for incremental solutions there.

Insane subsections of PEP 646 include my personal favourite "Splitting Arbitrary-length Tuples", which contains useful phrases that everyone likes to see in a formal specification:

  • "A final complication occurs when an unpacked arbitrary-length tuple is used as a type argument to an alias consisting of both TypeVars and a TypeVarTuple:"
  • "We assume the arbitrary-length tuple contains at least as many items as there are TypeVars, such that individual instances of the inner type - here int - are bound to any TypeVars present."
  • "The ‘rest’ of the arbitrary-length tuple - here *Tuple[int, ...], since a tuple of arbitrary length minus two items is still arbitrary-length - is bound to the TypeVarTuple."
  • "Of course, such splitting only occurs if necessary."
  • "In particularly awkward cases, a TypeVarTuple may consume both a type and a part of an arbitrary-length tuple type:"

A standard that says "A final complication occurs," "We assume," "only occurs if necessary," and "In particularly awkward cases," is a standard everyone should run from.

@beartype only supports PEP 646 because of autism. That's the reason – and it's not a great reason. I just wanted to solve the puzzle. I don't recommend that reason to others, though. Some puzzles are best left unsolved.

Just One Example of Eye-stinging Woe and Hardship

@antony-frolov's example above is a great example of the pain inherent in PEP 646. CPython currently provides no runtime API for detecting unpacked child tuple hints like *tuple[torch.Tensor, ...]. What's worse is that doing so manually is highly non-trivial. Unpacked child tuple hints do not have an unambiguous type. They have an ambiguous type shared by many different kinds of type hints.

@beartype thus hand-rolled its own ridiculous tester function:

def is_hint_pep646_tuple_unpacked_prefix(hint: Hint) -> bool:
    '''
    :data:`True` only if the passed hint is a :pep:`646`-compliant **prefix-based
    unpacked child tuple hint** (i.e., of the form "*tuple[{hint_child_child_1},
    ..., {hint_child_child_M}]" subscripting a parent tuple hint of the form
    "tuple[{hint_child_1}, ..., *tuple[{hint_child_child_1}, ...,
    {hint_child_child_M}], ..., {hint_child_N}]").

    If this tester returns :data:`True`, this unpacked child tuple hint is
    guaranteed to define the ``__args__`` dunder attribute to be either:

    * A 2-tuple ``({hint_child}, ...)``, in which case this child tuple hint
      unpacks to a variable-length tuple hint over ``{hint_child}`` types.
    * An n-tuple ``({hint_child_1}, ..., {hint_child_N})`` where ``...`` in this
      case is merely a placeholder connoting one or more child hints, in which
      case this child tuple hint unpacks to a fixed-length tuple hint over these
      exact ``{hint_child_I}`` types.

    This getter is intentionally *not* memoized (e.g., by the
    ``callable_cached`` decorator), as the implementation trivially reduces to
    an efficient one-liner.

    Motivation
    ----------
    Interestingly, even detecting accursed unpacked child tuple hints at runtime
    is highly non-trivial. They do *not* have a sane unambiguous type,
    significantly complicating detection. For some utterly inane reason, their
    type is simply the ambiguous type :class:`types.GenericAlias` (i.e.,
    :class:`.HintGenericSubscriptedType`). That... wasn't what we were expecting
    *at all*. For example, under Python 3.13:

    .. code-block:: python

       # Note that Python *REQUIRES* unpacked tuple type hints to be embedded in
       # some larger syntactic construct. So, just throw it into a list. This is
       # insane, because we're only going to rip it right back out of that list.
       # Blame the CPython interpreter. *shrug*
       >>> yam = [*tuple[int, str]]

       # Arbitrary unpacked tuple type hint.
       >>> yim = yam[0]
       >>> repr(yim)
       *tuple[int, str]  # <-- gud
       >>> type(yim)
       <class 'types.GenericAlias'>  # <-- *TOTALLY NOT GUD. WTF, PYTHON!?*

       # Now look at this special madness. The type of this object isn't even in
       # its method-resolution order (MRO)!?!? I've actually never seen that
       # before. The type of any object is *ALWAYS* the first item in its
       # method-resolution order (MRO), isn't it? I... guess not. *facepalm*
       >>> yim.__mro__
       (<class 'tuple'>, <class 'object'>)

       # So, "*tuple[int, str]" is literally both a tuple *AND* a "GenericAlias"
       # at the same time. That makes no sense, but here we are. What are the
       # contents of this unholy abomination?
       >>> dir(yim)
       ['__add__', '__args__', '__bases__', '__class__', '__class_getitem__',
       '__contains__', '__copy__', '__deepcopy__', '__delattr__', '__dir__',
       '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__',
       '__getitem__', '__getnewargs__', '__getstate__', '__gt__', '__hash__',
       '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__',
       '__mro_entries__', '__mul__', '__ne__', '__new__', '__origin__',
       '__parameters__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__',
       '__setattr__', '__sizeof__', '__str__', '__subclasshook__',
       '__typing_unpacked_tuple_args__', '__unpacked__', 'count', 'index']

       # Lotsa weird stuff there, honestly. Let's poke around a bit.
       >>> yim.__args__
       (<class 'int'>, <class 'str'>)  # <-- gud
       >>> yim.__typing_unpacked_tuple_args__
       (<class 'int'>, <class 'str'>)  # <-- gud, albeit weird
       >>> yim.__unpacked__
       True  # <-- *WTF!?!? what the heck is this nonsense?*
       >>> yim.__origin__
       tuple  # <-- so you lie about everything, huh?

    Basically, the above means that the only means of reliably detecting an
    unpacked tuple hint at runtime is by heuristically introspecting for both
    the existence *and* values of various dunder attributes exhibited above.
    Introspecting merely the existence of these attributes is insufficient; only
    the combination of both existence and values suffices to effectively
    guarantee disambiguity. Likewise, introspecting merely one or even two of
    these attributes is insufficient; only the combination of three or more of
    these attributes suffices to effectively guarantee disambiguity.

    Note there *are* no means of actually guaranteeing disambiguity. Malicious
    third-party objects could attempt to masquerade as unpacked child tuple
    hints by defining similar dunder attributes. We can only reduce the
    likelihood of false positives by increasing the number of dunder attributes
    introspected by this tester. Don't blame us. We didn't start the fire.

    Parameters
    ----------
    hint : Hint
        Type hint to be inspected.

    Returns
    -------
    bool
        :data:`True` only if this hint is an unpacked child tuple hint.
    '''
    # print(f'dir({hint}): {dir(hint)}')
    # print(f'{hint}.__class__: {hint.__class__}')
    # print(f'{hint}.__args__: {getattr(hint, '__args__', False)}')
    # print(f'{hint}.__typing_unpacked_tuple_args__: {getattr(hint, '__typing_unpacked_tuple_args__', True)}')

    # Return true only if...
    return (
        # This hint's type is that of all unpacked child tuple hints (as well as
        # many other unrelated first- and third-party hints, sadly) *AND*...
        hint.__class__ is HintGenericSubscriptedType and
        (
            # The tuple of all child hints subscripting this parent hint is also
            # the tuple of all child hints subscripting this unpacked child
            # tuple hint. Since only unpacked child tuple hints *SHOULD* define
            # the PEP 646-compliant and frankly outrageously verbose
            # "__typing_unpacked_tuple_args__" dunder attribute, this
            # equivalence *SHOULD* suffice to disambiguate this hint as an
            # unpacked child tuple hint.
            getattr(hint, '__args__', False) ==
            getattr(hint, '__typing_unpacked_tuple_args__', True)
        )
        #FIXME: Probably unnecessary for the moment. Let's avoid violating
        #fragile privacy encapsulation any more than we must, please. *sigh*
        # # This hint's method-resolution order (MRO) is that of all unpacked
        # # child tuple hints *AND*...
        # getattr(hint, '__mro__', None) == _PEP646_HINT_TUPLE_UNPACKED_MRO and  # pyright: ignore
        # # This hint defines a dunder attribute uniquely defined *ONLY* by
        # # unpacked child tuple hints with a value guaranteed to be set by all
        # # unpacked child tuple hints.
        # getattr(hint, '__unpacked__', None) is True
    )

And, of course, that still fails to suffice for the general case. Why? Because typing.Unpack[tuple[torch.Tensor, ...]] is a completely different syntax that's semantically equivalent to *tuple[torch.Tensor, ...]. They mean the exact same thing. Yet, they're completely different kinds of runtime objects that share no similarities whatsoever.

Uhh... So What Are You Saying, Exactly?

@beartype already solved the problem. jaxtyping and @beartype are best friends. Rather than reinvent the wheel of thorns, two choices present themselves.

The first choice isn't great. In fact, it's awful. I present it for completeness. jaxtyping could conditionally defer to @beartype when @beartype is importable... regardless of whether the @beartype-specific @jaxtyped(typechecker=beartype) decorator is applied or not. @beartype currently has a number of private APIs to solve PEP 646 and similar QA decision problems. The problem is... they're all private at the moment. They also tend to change a lot. Like, a lot a lot. They're moving targets, because typing standards are moving targets.

The second choice is possibly great. Rather than publicize those moving targets as fragile APIs that will inevitably break jaxtyping, @beartype could instead begin integrating support for jaxtyping at a super-deep level. Specifically, @beartype could:

  1. Detect jaxtyping-specific type hints like jaxtyping.Float[torch.Tensor, "t dim"].
  2. When detected, automatically generate code validating that the corresponding tensors satisfy those jaxtyping-specific type hints – regardless of whether those tensors were originally specified as decorated parameters, decorated returns, PEP 526-compliant annotated variable assignments, die_if_unbearable() parameters, is_bearable() parameters, or whatevahs.

At that point, the @jaxtyped decorator and its associated import hook would no longer be needed when @beartype is used. Users could still use those if they like, but they'd (more or less) just reduce to a mildly expensive noop. Internally, @beartype would transparently detect and use jaxtyping and its associated machinery as needed.

Interestingly, we've actually broached this subject before over at #153. @avolchek (Andrei Volchek) discovered this insane kludge to monkey-patch jaxtyping and @beartype together in exactly this way: e.g.,

On detecting a callable annotated by one or more jaxtyping type hints, dynamically inject jaxtyping setup and teardown logic into the type-checking wrapper function that @beartype generates as follows:

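import jaxtyping

def __beartype_wrapper(*args, **kwargs):
    # Setup: open a fresh shape memo for this call (private jaxtyping API).
    jaxtyping._storage.push_shape_memo({})

    ...  # <-- *BEARTYPE MAGIC HAPPENS HERE*
    __beartype_pith_0 = ...  # <-- *MORE MAGICAL UNICORNS ERUPT*

    # Teardown: discard that memo before returning the checked value.
    jaxtyping._storage.pop_shape_memo()
    return __beartype_pith_0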

That... looks super-trivial to me. So why haven't I or anyone else officially made this integrative magic happen then? Laziness and politeness. But mostly laziness. But also politeness.

I don't want @beartype to step on anyone's delicate toes or APIs. I applaud @patrick-kidger for his wondrous accomplishments here and everywhere else. He is amazing. Without jaxtyping, far fewer users would use @beartype. Because I respect and adore jaxtyping, @beartype shouldn't "muscle in" on jaxtyping territory without the explicit approval of everyone involved.

Also, jaxtyping._storage is yet another fragile private API. @beartype probably shouldn't be violating that privacy encapsulation. Perhaps the relevant push_shape_memo() and pop_shape_memo() methods could be publicly exposed to @beartype at some point? You wouldn't even need to document them, really. Only @beartype – and possibly typeguard as well – would ever be expected to call those bizarro-world methods.

William Riker synopsizes my feelings on this topic.

@beartype has a beard like this, too

leycec avatar Aug 05 '25 21:08 leycec

Firstly, though, this particular example doesn't seem quite right. This:

...
) -> tuple[torch.Tensor, "t dim"], *tuple[torch.Tensor, ...]]:

...isn't a valid PEP 646-compliant type hint. I think, anyway. Because PEP 646 is nightmare fuel, it's hard to say. But I'm pretty sure you can't directly comma-delimit type hints. They have to subscript a parent type hint. You probably instead meant:

...
) -> tuple[Float[torch.Tensor, "t dim"], *tuple[torch.Tensor, ...]]:

Something like that, maybe?

Oh, yes, that's exactly what I wanted to write, sorry for the confusion. And thanks for the detailed comment! Not sure I'll be able to meaningfully contribute to further discussion, but will be happy to see it unravel!

antony-frolov avatar Aug 05 '25 21:08 antony-frolov

This discussion is pure gold 🤩 love the constructive rant.

I had never heard of a variable-length tuple as a type annotation, and it seems kind of antithetic to the idea of type checking to me. Now I would vote to not support it 😅

But to suggest a pragmatic solution that is not a whatever-tuple, can you use @overload in your code, @antony-frolov? Then you can still type-check, but only for the cases that matter to you.
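
A rough sketch of what that could look like, using float and int as stand-ins for the tensor hints so that it runs without torch (the arities covered and the doubling body are assumptions for illustration only):

```python
from typing import overload

# Stand-ins: float plays the role of the shaped tensor, int the extra tensors.


@overload
def forward(x: float) -> tuple[float]: ...
@overload
def forward(x: float, a: int) -> tuple[float, int]: ...
@overload
def forward(x: float, a: int, b: int) -> tuple[float, int, int]: ...


def forward(x, *args):
    # Single runtime implementation; the overloads above only describe the
    # fixed arities that are worth spelling out.
    return (x * 2.0, *args)


print(forward(1.0))
print(forward(1.0, 2))
print(forward(1.0, 2, 3))
```

Note that the overload stubs are mainly aimed at static checkers. The implementation itself is left unannotated here, so whether a given runtime checker sees anything useful is a separate question.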

johannahaffner avatar Aug 05 '25 21:08 johannahaffner

it seems kind of antithetic to the idea of type checking to me.

True words have now been spoken.

Now I would vote to not support it 😅

I see your anti-vote and raise you another anti-vote. Indeed, PEP 646 kinda sucks. It's just an endless barrage of increasingly nonsensical edge cases, which then makes it super-rough for either devs or users to understand exactly what's going on and why.

It really seems like the PEP 646 authors lost sight of the forest for the trees. They meant well, but humans often do. The goal state for any typing standard should be to reduce bugs. That's the whole point of type-checking in the first place, right?

Instead, PEP 646 is so gargantuan and unintelligible that it only promotes primordial confusion and chaos. Which... is exactly what leads to bugs! 😮‍💨

Everybody Hates PEP 646. What's the Problem Then?

The problem is that PEP 646 will never go away. It's like the Terminator of QA standards. The explicit use case of PEP 646 is literally jaxtyping: tensor typing. This means that jaxtyping users will always want PEP 646 to happen.

For example, this is literally the first example in PEP 646:

from __future__ import annotations  # Array is named in its own method annotations

from typing import Generic, NewType, TypeVar, TypeVarTuple
from typing import Literal as L

DType = TypeVar('DType')
Shape = TypeVarTuple('Shape')

class Array(Generic[DType, *Shape]):
    def __abs__(self) -> Array[DType, *Shape]: ...
    def __add__(self, other: Array[DType, *Shape]) -> Array[DType, *Shape]: ...

Height = NewType('Height', int)
Width = NewType('Width', int)

x: Array[float, Height, Width] = Array()
y: Array[float, L[480], L[640]] = Array()

Users like that kind of syntax. It's Pythonic. It doesn't require stringified attribute names like Float[torch.Tensor, "t dim"], which looks vaguely suspicious to many users.

Users thus love PEP 646 – even if they don't particularly understand it. After all, nobody understands PEP 646. But that's fine. PEP 646 syntax is Pythonic to read and type, which is all most users care about. "Understanding" comes a distant second (if it rates at all).

If jaxtyping doesn't get there first, then another third-party tensor typing API will. My vote is for jaxtyping and @beartype to become even closer friends. That's how jaxtyping gets there first.

@beartype + jaxtyping = 💪 🐻

leycec avatar Aug 06 '25 05:08 leycec

Ok, so! A couple of different things going on here now.


First of all, I'm actually a little bit surprised that this doesn't work already! When jaxtyping uses an existing typechecker, it is to defer the type-parsing-and-traversing to that typechecker. Our contribution to this is pretty much just to provide a type that can be isinstanced against.
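
To illustrate what "a type that can be isinstanced against" means in general terms, here is a toy sketch. It is not jaxtyping's actual machinery: FakeTensor, Rank2, and check_return are hypothetical names, and the check is just a rank comparison.

```python
class FakeTensor:
    """Stand-in for a tensor; only carries a shape."""

    def __init__(self, *shape):
        self.shape = shape


class _RankMeta(type):
    """Metaclass that turns isinstance() into a shape check."""

    def __instancecheck__(cls, obj):
        return isinstance(obj, FakeTensor) and len(obj.shape) == cls.rank


class Rank2(metaclass=_RankMeta):
    """Hint-like class: 'instances' are FakeTensors with exactly two axes."""

    rank = 2


def check_return(value, hint):
    # A generic checker needs no knowledge of shapes: it only calls
    # isinstance(), and the hint's metaclass does the rest.
    if not isinstance(value, hint):
        raise TypeError(f"{value!r} does not satisfy {hint.__name__}")
    return value


print(isinstance(FakeTensor(3, 4), Rank2))
print(isinstance(FakeTensor(3), Rank2))
```

Under that division of labour, walking a hint like tuple[..., *tuple[..., ...]] would be the checker's job, and each leaf would only have to answer isinstance().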

Do you have a code snippet that can be copy-pasted to demonstrate the problem? (Including imports etc)


Now this has brought up a second point about jaxtyping-beartype integration! I think this would be pretty awesome and I am in favor of this. As @leycec notes, jaxtyping and beartype are best friends.

In terms of what that would mean, I think it might be slightly more work than as described above 😅 Namely:

  • A way to pass state in-and-out of each __instancecheck__ method. Right now jaxtyping has this threadlocal stack and it's responsible for all kinds of nightmare fuel.
  • A way to offer string error messages describing why failure occurred. Oh wait, we already have __instancecheck_str__, so we're good to go there!
  • Support for full-fat O(n) typechecking. Yup, this is the big one. It's pretty much essential for many of our users so absent this it doesn't make much sense to do more.

With these pieces in place then beartype could call into jaxtyping via essentially checking hasattr(hint, "__instancecheck_str_stateful__") and using it if so. In particular you can see that I'm definitely aiming for a clear protocol for jaxtyping<>beartype... in large part because I want to get rid of this threadlocal insanityness, not bake it in further 😄
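
A very rough sketch of how such a protocol might look from the checker's side. Everything here is an assumption for the sake of discussion: the signature of __instancecheck_str_stateful__ (object plus a state dict, returning an empty string on success), the Axes hint class, and the check helper are all made up, and none of this exists in either library.

```python
class Axes:
    """Hypothetical hint: matches a shape tuple against named axes."""

    def __init__(self, *names):
        self.names = names

    def __instancecheck_str_stateful__(self, obj, state):
        # Assumed contract: return "" on success, else a failure message.
        # "state" is passed in explicitly instead of living in a thread-local.
        if not isinstance(obj, tuple) or len(obj) != len(self.names):
            return f"{obj!r} does not have axes {self.names}"
        for size, name in zip(obj, self.names):
            if state.setdefault(name, size) != size:
                return f"axis {name!r} is {size}, but was {state[name]} earlier"
        return ""


def check(obj, hint, state):
    """Checker-side dispatch: use the stateful hook when the hint offers one."""
    hook = getattr(hint, "__instancecheck_str_stateful__", None)
    if hook is not None:
        return hook(obj, state)
    # Fallback for ordinary classes.
    if isinstance(obj, hint):
        return ""
    return f"{obj!r} is not an instance of {hint!r}"


state = {}  # one state object per checked call
print(repr(check((3, 4), Axes("t", "dim"), state)))
print(repr(check((5,), Axes("t"), state)))
print(repr(check(1, int, state)))
```

The appeal of this shape is that the checker owns the lifetime of the state object, so nothing needs to be pushed to or popped from a global stack.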

patrick-kidger avatar Aug 06 '25 07:08 patrick-kidger

Yes! Let's do this. Please meet me over at beartype/beartype#544 for a feature request throw-down. Let's spec some integrations out.

As for this issue... I'd probably close it. PEP 646 is out of scope for everybody except Pydantic, @beartype, and typeguard. We have to support absurd standards like PEP 646, because we have no choice; that's what we voluntarily signed up for when we started heading down the path of madness.

jaxtyping, though? jaxtyping definitely should not have to care about PEP 646. That's way beyond the purview of "normal" packages. Thank your lucky stars, @patrick-kidger. We spared you some nightmare fuel.

leycec avatar Aug 08 '25 05:08 leycec