
Identical jaxtyping types are not == (as of v0.2.35)

Open · calbach opened this issue 6 months ago · 2 comments

Hello @patrick-kidger. I checked the release notes and didn't find this mentioned, so I don't know whether it's an intentional behavioral change or not. I bisected the issue to release v0.2.35.

```python
from jaxtyping import Float, Array

# As of v0.2.35, returns False.
Float[Array, "..."] == Float[Array, "..."]
```

I only tested Float, but I'd guess this is true for the other types as well.
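To see how two identical-looking annotations can compare unequal, here is a minimal pure-Python sketch. It assumes, in line with the reply below, that each subscription builds a brand-new class and that classes compare by identity. `make_array_type` is a hypothetical stand-in, not jaxtyping's actual internals.

```python
def make_array_type(dtype_name: str, shape: str) -> type:
    """Hypothetical stand-in for `Float[Array, shape]`: builds a fresh class on every call."""
    name = f"{dtype_name}[Array, '{shape}']"
    # type() creates a new class object each time; classes hash and compare by identity.
    return type(name, (), {"dtype_name": dtype_name, "shape": shape})


a = make_array_type("Float", "...")
b = make_array_type("Float", "...")

print(a.__name__ == b.__name__)  # True: same name and attributes
print(a == b)                    # False: two distinct class objects
```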

Here are some practical side effects when using jaxtyping:

1. Comparing function signatures dynamically

```python
from jaxtyping import Float, Array
import inspect

def f(x: Float[Array, "a b"]): ...
def g(x: Float[Array, "a b"]): ...

# Now fails
assert inspect.signature(f) == inspect.signature(g)
```
2. Using the `overrides` package

```python
from jaxtyping import Float, Array
from overrides import override

class P:
    def f(self, x: Float[Array, "..."]): ...

class C(P):
    @override
    def f(self, x: Float[Array, "..."]): ...

# TypeError: `C.f: x must be a supertype of `<class 'jaxtyping.Float[Array, '...']'>` but is `<class 'jaxtyping.Float[Array, '...']'>`
```
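The signature comparison fails because `inspect.Parameter` equality compares annotations with `==`, which is identity for these classes. One possible user-side workaround, sketched below under the assumption that annotations with the same `repr` can be treated as equivalent for your purposes (not guaranteed in general), is to compare a repr-based key instead. `make_annotation` and `signature_key` are hypothetical helpers; the fresh-class factory only mimics the per-subscription behaviour described in this thread.

```python
import inspect


def make_annotation(shape: str) -> type:
    # Hypothetical stand-in for `Float[Array, shape]`: a fresh class per call.
    return type(f"Float[Array, '{shape}']", (), {})


def f(x: make_annotation("a b")): ...
def g(x: make_annotation("a b")): ...


def signature_key(fn) -> tuple:
    """Hypothetical helper: describe a signature using the repr of each annotation."""
    sig = inspect.signature(fn)
    params = tuple(
        (p.name, p.kind, repr(p.annotation), repr(p.default))
        for p in sig.parameters.values()
    )
    return params, repr(sig.return_annotation)


# Parameter equality compares annotations with ==, i.e. by identity here.
print(inspect.signature(f) == inspect.signature(g))  # False
# Comparing by repr treats the two annotations as the same.
print(signature_key(f) == signature_key(g))          # True
```

This does not help libraries such as `overrides` that perform their own subtype checks; it only covers code you control.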

calbach, Jun 17 '25 00:06

Hey CH! Good to hear from you.

So this is intended, whilst admittedly also being something I would like to fix if I can.

We're caught between two other constraints here:

  • First of all, hashing and equality of jaxtyping Float[...] objects have to be by identity, not by checking their attributes. This is due to handling deserialization, when some libraries (in particular cloudpickle) will actually use partially-initialised classes as a cache key. Which is a bit sketchy on their part, but as a practical matter a lot of folks use jaxtyping in multiprocessing contexts, and changing this behaviour leads to a crash.

    Correspondingly we might imagine that the solution is simply to cache the creation of jaxtyping objects (so that comparisons by identity exhibit the same semantics as comparisons by attribute), i.e. to add a functools.cache to this function call:

    https://github.com/patrick-kidger/jaxtyping/blob/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/jaxtyping/_array_types.py#L599

    However that then leads us to our second problem...

  • ...which is that we would like to be able to type generators as e.g. `def my_generator() -> Iterator[Float[Array, "foo"]]`. Typechecking is performed when the generator is instantiated (we typecheck function calls). However, a generator only yields when advanced via `next`, and at that point we're actually in some other jaxtyping context (the mapping from axis names to axis sizes), probably the one for typechecking whichever function we're in when we call `next`:

```python
from collections.abc import Iterator

from beartype import beartype as typechecker  # assumed typechecker, as in the jaxtyping docs
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=typechecker)  # context 1
def my_generator(x: Float[Array, "foo"]) -> Iterator[Float[Array, "foo"]]:
    yield x


@jaxtyped(typechecker=typechecker)  # context 2
def foo(x: Float[Array, "foo"]):
    gen = my_generator(x[:5])  # The type annotations during instantiation are checked in 'context 1' and we have `foo=5`
    next(gen)  # But at this point we've dropped that context, and now the iterator is checked from within 'context 2' and we have `foo=<something else>`.
```

    This would lead to spurious errors. Correspondingly, we actually disable shape-checking for generators:

    https://github.com/patrick-kidger/jaxtyping/blob/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/jaxtyping/_decorator.py#L351

    by setting a flag on the Float[...] object itself. However, in particular, that means that we cannot cache this object! We'd be mutating the single shared object used everywhere that someone happened to use the same shape and dtype annotations.
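Both constraints can be reproduced in a small pure-Python sketch. Everything here is hypothetical: `current_axes`, `checked`, `make_array_type` and the `check_shape` attribute stand in for jaxtyping internals and are not its real API. The first part shows that a decorator wrapping a generator function only runs at instantiation; since generators do not capture a context of their own, the body executes under whatever binding is active at `next()`. The second part shows that caching makes identity equality agree with attribute equality, but turns a per-use flag into a global one.

```python
import contextvars
import functools

# Part 1: generator timing.
# Hypothetical stand-in for the per-call context (axis name -> axis size).
current_axes = contextvars.ContextVar("current_axes", default=None)


def checked(fn):
    """Hypothetical decorator: binds axis sizes only for the duration of one call."""
    @functools.wraps(fn)
    def wrapper(x):
        token = current_axes.set({"foo": len(x)})
        try:
            return fn(x)
        finally:
            current_axes.reset(token)
    return wrapper


@checked
def my_generator(x):
    # The body does not start running until next() is called.
    yield current_axes.get()


@checked
def caller(x):
    gen = my_generator(x[:5])  # wrapper binds foo=5, but only creates the generator
    return next(gen)           # the body runs now, under caller's binding


print(caller(list(range(8))))  # {'foo': 8}, not {'foo': 5}


# Part 2: caching makes identical subscriptions return the same object...
@functools.cache
def make_array_type(dtype_name: str, shape: str) -> type:
    """Hypothetical cached stand-in for `Float[Array, shape]`."""
    return type(f"{dtype_name}[Array, '{shape}']", (), {"check_shape": True})


a = make_array_type("Float", "foo")
b = make_array_type("Float", "foo")
print(a is b)  # True

# ...but switching off shape checks for one generator by setting a flag on
# the shared object switches it off for every other user of that annotation.
a.check_shape = False
print(make_array_type("Float", "foo").check_shape)  # False
```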


I'd be very happy to take suggestions on a better solution to this.

patrick-kidger, Jun 17 '25 07:06

Thanks for the detailed explanation @patrick-kidger. It seems like immutable types would be the path of least resistance; perhaps you could consider tracking externally (e.g. with a ContextVar) that the generator's type is transparent within the scope of use, instead of mutating the object. Admittedly I don't fully understand the entrypoints / hooks jaxtyping has available, so I don't know if that's feasible.

Either way, it's not a major issue for me at the moment.

calbach, Jun 17 '25 19:06