Strange interaction between Shaped and Union

Open Qwlouse opened this issue 2 years ago • 4 comments

The jaxtyping.Shaped / Float / Int etc. interact strangely with typing.Union:

from typing import Union
from jaxtyping import Array, Shaped
import jax.numpy as jnp
import numpy as np

x = jnp.zeros([3])

# These all work:
assert isinstance(x, Array)
assert isinstance(x, Union[Array, int])
assert isinstance(x, Shaped[Array, "_"])
# But this one fails:
assert isinstance(x, Union[Shaped[Array, "_"], int])  # <<  AssertionError

# interestingly this one works:
assert isinstance(x, Shaped[Array, "_"] | int)

My use case was to define an alias that accepts both jax.Array and np.ndarray, which I first tried, without luck, like so:

assert isinstance(x, Shaped[Union[Array, np.ndarray], "_"])  # << AssertionError
assert isinstance(x, Shaped[Array | np.ndarray, "_"])
# TypeError: type 'types.UnionType' is not an acceptable base type

For now I can use Shaped[Array, "_"] | Shaped[np.ndarray, "_"], but this behavior was very surprising to me and seems like a bug.

Qwlouse, Apr 01 '23 13:04

Oh, this is a mess of Python typing misfeatures and bugs.

So Union[...] types are only sort-of compatible with isinstance checks. In earlier versions of Python they simply errored out; in later versions of Python they just delegate to type and issubclass, in blatant disregard for custom __instancecheck__ methods, or for the fact that "non-faithful" types exist (those types y for which isinstance(x, y) does not imply issubclass(type(x), y), e.g. jaxtyping types, for which y carries additional metadata about the shape and dtype).

There's no good reason for this. Union was originally introduced for static type checking, and runtime considerations have been almost entirely ignored by the Python standards committees.
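A minimal, JAX-free sketch of the "non-faithful" idea: the hypothetical `ListOfThree` class below (not part of jaxtyping) gets a metaclass whose `__instancecheck__` looks at the value's length, loosely as jaxtyping looks at shape and dtype. Checking the value succeeds, while checking only its type via `issubclass(type(x), ...)` cannot see that metadata.

```python
# Hypothetical metaclass standing in for jaxtyping's shape/dtype checks:
# membership depends on the value (its length), not only on its type.
class _LengthMeta(type):
    def __instancecheck__(cls, obj):
        return isinstance(obj, list) and len(obj) == cls.length


class ListOfThree(metaclass=_LengthMeta):
    length = 3


x = [0.0, 0.0, 0.0]
via_value = isinstance(x, ListOfThree)       # runs _LengthMeta.__instancecheck__
via_type = issubclass(type(x), ListOfThree)  # only sees `list`, so the length is lost
print(via_value, via_type)  # True False
```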

To further complicate life, in Python 3.10, A | B produces an object of type types.UnionType, whilst Union[A, B] produces a typing._UnionGenericAlias object.

Once again, there is no good reason for this discrepancy: runtime considerations have been almost entirely ignored by the standards committee.

And so for some weird reason, as A | B is a different object to Union[A, B], here we find that the former is compatible with isinstance checks in the way you're doing (which is discussed in the relevant PEP), but the latter is not. (Which is a detail that goes entirely undiscussed in the PEP.)

And, once again! There is no good reason for this. The standards committee have not given any thought to runtime considerations. Yes, this is maddening.


Anyway, what can we do about it? This:

assert isinstance(x, Shaped[Union[Array, np.ndarray], "_"])

fails because of the diversion via type and issubclass. The fix for this would be to upstream a fix to Python itself.

Meanwhile this:

assert isinstance(x, Shaped[Array | np.ndarray, "_"])

fails because A | B is a different type to Union[A, B] -- an odd detail that I didn't know about until now -- and so this check in jaxtyping for union handling does not trigger. The fix for this would be to adjust jaxtyping to look for typing.Union -- and on later versions of Python only, also types.UnionType.

I would be happy to accept a PR on this.

Meanwhile, if you want some code that will work today, and is version-compatible across all jaxtyping and Python versions, then you can unpack a Union like so:

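import typing
import numpy as np
x = jnp.zeros([3])
# Shaped[Union[...], ...] is unpacked into a Union of Shaped types; get_args returns them as a tuple
y = Shaped[Union[Array, np.ndarray], "_"]
assert isinstance(x, typing.get_args(y))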
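Why this workaround holds up can be shown without JAX: `typing.get_args` returns a plain tuple of the union's members, and `isinstance` with a tuple calls `isinstance` on each member, so custom `__instancecheck__` methods run. `NonEmptyList` below is a hypothetical stand-in for a jaxtyping type.

```python
import typing
from typing import Union


class _NonEmptyMeta(type):
    # Hypothetical value-dependent check, standing in for a jaxtyping shape check.
    def __instancecheck__(cls, obj):
        return isinstance(obj, list) and len(obj) > 0


class NonEmptyList(metaclass=_NonEmptyMeta):
    pass


alias = Union[NonEmptyList, int]
members = typing.get_args(alias)  # (NonEmptyList, int): a plain tuple of classes

# isinstance() with a tuple calls isinstance() on each member in turn, so the
# custom __instancecheck__ is honoured on any Python with typing.get_args (3.8+).
print(isinstance([1], members), isinstance([], members), isinstance(7, members))  # True False True
```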

patrick-kidger, Apr 02 '23 20:04

The crash with Shaped[A | B, ...] should be fixed in #77 (version 0.2.15).

patrick-kidger, Apr 13 '23 18:04

Are there plans to support A | B the same way Union[A, B] works in jaxtyping? Since we're on Python 3.12, an aggressive linter tries to rewrite every Union[A, B] as A | B, not realizing there's a subtle difference.

this check in jaxtyping for union handling does not trigger

jaxtyping therefore doesn't fail on a wrongly-typed value if given A | B, and only fails with Union[A, B].

I apologize if I missed an issue about supporting the pipe at parity with Union.

if it's only a matter of someone working on this, I'm happy to try

yash-s20, Oct 24 '25 18:10

A | B should already be supported, as per #77. You appear to be linking to a commit from about 2 years ago, prior to this fix :)

patrick-kidger, Oct 24 '25 20:10