Strange interaction between Shaped and Union
jaxtyping.Shaped / Float / Int etc. interact strangely with typing.Union:
from typing import Union
from jaxtyping import Array, Shaped
import jax.numpy as jnp
import numpy as np
x = jnp.zeros([3])
# These all work:
assert isinstance(x, Array)
assert isinstance(x, Union[Array, int])
assert isinstance(x, Shaped[Array, "_"])
# But this one fails:
assert isinstance(x, Union[Shaped[Array, "_"], int]) # << AssertionError
# interestingly this one works:
assert isinstance(x, Shaped[Array, "_"] | int)
My use case was to define an alias that accepts both jax.Array and np.ndarray. I first tried the following, without luck:
assert isinstance(x, Shaped[Union[Array, np.ndarray], "_"]) # << AssertionError
assert isinstance(x, Shaped[Array | np.ndarray, "_"])
# TypeError: type 'types.UnionType' is not an acceptable base type
For now I can use Shaped[Array, "_"] | Shaped[np.ndarray, "_"], but this behavior was very surprising to me and seems like a bug.
Oh, this is a mess of Python typing misfeatures and bugs.
So Union[...] types are only sort-of compatible with isinstance checks. In earlier versions of Python they simply errored out; in later versions they just delegate to type and issubclass, in blatant disregard for custom __instancecheck__ methods, or the fact that "non-faithful" types exist (those types y for which isinstance(x, y) does not imply issubclass(type(x), y), e.g. jaxtyping types, for which y carries additional metadata about the shape and dtype).
There's no good reason for this. Union was originally introduced for static type checking, and runtime considerations have been almost entirely ignored by the Python standards committees.
To further complicate life, in Python 3.10 A | B produces an object of type types.UnionType, whilst Union[A, B] produces a typing._UnionGenericAlias object.
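To make the "non-faithful" point concrete, here is a minimal standard-library sketch. Positive and its metaclass are hypothetical stand-ins, not jaxtyping code, and the last check only emulates the type-plus-issubclass delegation described above rather than calling isinstance on a Union directly, since that result depends on the Python version:

```python
from typing import Union, get_args


class _PositiveMeta(type):
    # Custom instance check: membership depends on the value, not just its class.
    def __instancecheck__(cls, obj):
        return type(obj) is int and obj > 0


class Positive(metaclass=_PositiveMeta):
    """Hypothetical non-faithful type: matches positive ints only."""


x = 5

# A direct check goes through the custom __instancecheck__.
direct = isinstance(x, Positive)

# Positive is not a base class of int, so the subclass check fails.
via_subclass = issubclass(type(x), Positive)

# Emulate the delegation: test type(x) against each member of the union.
members = get_args(Union[Positive, str])
delegated = any(issubclass(type(x), member) for member in members)

print(direct, via_subclass, delegated)
```

The direct check succeeds while both subclass-based checks fail, which is the same shape of mismatch as in the jaxtyping example.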
Once again, there is no good reason for this discrepancy: runtime considerations have been almost entirely ignored by the standards committee.
And so for some weird reason, as A | B is a different object to Union[A, B], here we find that the former is compatible with isinstance checks in the way you're doing (which is discussed in the relevant PEP), but the latter is not. (Which is a detail that goes entirely undiscussed in the PEP.)
And, once again! There is no good reason for this. The standards committee have not given any thought to runtime considerations. Yes, this is maddening.
Anyway, what can we do about it? This:
assert isinstance(x, Shaped[Union[Array, np.ndarray], "_"])
fails because of the diversion via type and issubclass. The fix for this would be to upstream a fix to Python itself.
Meanwhile this:
assert isinstance(x, Shaped[Array | np.ndarray, "_"])
fails because A | B is a different type to Union[A, B] -- an odd detail that I didn't know about until now -- and so this check in jaxtyping for union handling does not trigger. The fix for this would be to adjust jaxtyping to look for typing.Union -- and on later versions of Python only, also types.UnionType.
I would be happy to accept a PR on this.
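For anyone picking this up, detecting both spellings might look roughly like the sketch below. The names is_union and union_members are hypothetical, and this is not jaxtyping's actual implementation; it only uses typing.get_origin and typing.get_args, and guards the A | B cases so it still runs on Python versions that lack the syntax:

```python
import sys
import types
from typing import Union, get_args, get_origin

# types.UnionType only exists on Python 3.10+; fall back to typing.Union otherwise.
_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))


def is_union(tp):
    """Return True for Union[A, B] and, where supported, A | B."""
    return get_origin(tp) in _UNION_ORIGINS


def union_members(tp):
    """Return the member types of a union, or (tp,) for anything else."""
    return get_args(tp) if is_union(tp) else (tp,)


assert is_union(Union[int, str])
assert not is_union(int)
assert union_members(Union[int, str]) == (int, str)
assert union_members(int) == (int,)

# The pipe spelling is only evaluated on versions that support it.
if sys.version_info >= (3, 10):
    assert is_union(int | str)
    assert union_members(int | str) == (int, str)
```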
Meanwhile, if you want some code that will work today, and is version-compatible across all jaxtyping and Python versions, then you can unpack a Union like so:
import typing
x = jnp.zeros([3])
y = Shaped[Union[Array, np.ndarray], "_"]
# get_args unpacks the union into a tuple of types, which isinstance accepts.
assert isinstance(x, typing.get_args(y))
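The reason the unpacking helps is that isinstance with a tuple checks each member in turn, so custom instance checks are consulted. A standard-library sketch of the same idea, where Even and isinstance_any are hypothetical names used only for illustration:

```python
from typing import Union, get_args


class _EvenMeta(type):
    # Custom instance check based on the value, similar in spirit to a shape check.
    def __instancecheck__(cls, obj):
        return type(obj) is int and obj % 2 == 0


class Even(metaclass=_EvenMeta):
    """Hypothetical stand-in for a type with a custom instance check."""


def isinstance_any(obj, union_type):
    # get_args turns Union[A, B] into the tuple (A, B); isinstance then
    # checks obj against each member of that tuple individually.
    return isinstance(obj, get_args(union_type))


matches_even = isinstance_any(4, Union[Even, str])
matches_str = isinstance_any("a", Union[Even, str])
matches_neither = isinstance_any(3, Union[Even, str])

print(matches_even, matches_str, matches_neither)
```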
The crash with Shaped[A | B, ...] should be fixed in #77 (version 0.2.15).
Are there plans to support A | B the same way Union[A, B] works in jaxtyping? Since we're on Python 3.12, an aggressive linter tries to rewrite every Union[A, B] as A | B, not realizing there's a subtle difference.
jaxtyping therefore doesn't fail despite the wrong type if given A | B, and only fails with Union[A, B].
I apologize if I missed an existing issue about supporting the pipe syntax at parity with Union.
If it's only a matter of someone working on this, I'm happy to try.
A | B should already be supported, as per #77. You appear to be linking to a commit from about 2 years ago, prior to this fix :)