
Problem with symbolic shapes

Open sbodenstein opened this issue 5 months ago • 4 comments

This

```python
import jax
from jax import export
from jax import numpy as jnp
import jaxtyping as jt
import typeguard

@jt.jaxtyped(typechecker=typeguard.typechecked)
def f(
    x: jt.Float[jt.Array, "*#B"],
) -> jt.Float[jt.Array, "*#B"]:
    return x * jnp.sum(x) ** 2

# Export with a symbolic (shape-polymorphic) dimension named "b".
dtype = jnp.float32
x_shape = export.symbolic_shape("b")
export.export(jax.jit(f))(
    jax.ShapeDtypeStruct(x_shape, dtype)
)
```

fails with

```
jaxtyping.TypeCheckError: Type-check error whilst checking the return value of __main__.f.
Actual value: f32[b](jax)
Expected type: Float[Array, '*#B'].
----------------------
Called with parameters: {'x': f32[b](jax)}
Parameter annotations: (x: Float[Array, '*#B']) -> Any.
The current values for each jaxtyping axis annotation are as follows.
B=(b,)
```

The problem seems to be the combination *# (it works if either * or # is removed).

sbodenstein, Aug 12 '25 10:08

Thanks for the report! (Though I'll note that your google3 usage isn't reproducible in the outside world ;) ) With that trimmed out, the error is ultimately raised from this line:

https://github.com/patrick-kidger/jaxtyping/blob/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/jaxtyping/_array_types.py#L298

because, naturally, jaxtyping/NumPy does not know how to broadcast JAX's symbolic shapes against each other.

These kinds of symbolic shapes have been floating around in JAX for a while, but since you're using them I'm guessing they've now graduated into the public API. So far, at least, we have only supported shapes that are tuples of integers.
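To make the failure mode concrete, here is a simplified model (a sketch, not jaxtyping's actual implementation) of how a *#B annotation might be checked: the first array binds B, and every later array must broadcast against the bound shape. The broadcasting below only understands integer sizes, so a symbolic dimension (stood in for by any non-int object) passes the binding step but fails the broadcast step. That is consistent with the report, where the parameter x is accepted and the return value is rejected. The names broadcast_dim, broadcast_shapes and check_variadic_broadcast are hypothetical.

```python
def broadcast_dim(a, b):
    """Broadcast two dimension sizes, NumPy-style; integers only."""
    # Assumption for this model: only plain ints are understood, mirroring
    # "shapes that are tuples of integers". Symbolic dims are rejected.
    if not (isinstance(a, int) and isinstance(b, int)):
        raise TypeError(f"cannot broadcast non-integer dims {a!r} and {b!r}")
    if a == b or b == 1:
        return a
    if a == 1:
        return b
    raise ValueError(f"incompatible dims {a} and {b}")


def broadcast_shapes(s1, s2):
    """Right-aligned broadcasting of two shapes (pads the shorter with 1s)."""
    n = max(len(s1), len(s2))
    p1 = (1,) * (n - len(s1)) + tuple(s1)
    p2 = (1,) * (n - len(s2)) + tuple(s2)
    return tuple(broadcast_dim(a, b) for a, b in zip(p1, p2))


def check_variadic_broadcast(memo, name, shape):
    """Hypothetical check for a '*#name' annotation.

    The first time `name` is seen, its shape is simply recorded (no
    broadcasting needed). Later shapes must broadcast against it.
    """
    if name not in memo:
        memo[name] = tuple(shape)
        return True
    try:
        broadcast_shapes(memo[name], shape)
    except (TypeError, ValueError):
        return False
    return True
```

With integer shapes this behaves like ordinary broadcasting, e.g. a bound B=(3,) accepts (1,) and rejects (4,). With a non-integer dimension the first check (binding) succeeds and the second one fails, even for the identical dimension object.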

Unfortunately I don't think this is a thing we can reasonably support in jaxtyping.

  1. This is the kind of thing that tends to be pretty unstable in JAX, and I don't think I want to try and stay on top of whatever future changes happen here.
  2. The name 'jaxtyping' is now historical, and we aim to support all array/tensor libraries on roughly equal footing. Baking in JAXisms may complicate our ability to support other libraries as well.

patrick-kidger, Aug 12 '25 11:08

Thank you for the explanation!

How about supporting a mode that ignores shapes when they are symbolic? That could get users most of the way.

sbodenstein, Aug 21 '25 16:08

So, a few possible options:

  • it's possible to disable jaxtyping (jaxtyping.config.update("jaxtyping_disable", True)). You could do this prior to the symbolic section of your code.
  • it's very common to only use jaxtyping during tests. More generally, another option is simply not to use jaxtyping alongside symbolic shapes.
  • if you did want to avoid specifically the case in which just one array has a symbolic shape, then you could wrap jaxtyping in some kind of Foo[Float[Array, "x y z"]] such that type(Foo).__instancecheck__ delegates to the wrapped hint if and only if the input does not have a symbolic shape.
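A minimal sketch of the third option. SkipIfSymbolic and _SkipIfSymbolicMeta are hypothetical names, not part of jaxtyping. The sketch assumes that concrete shapes are tuples of Python ints (so any other entry, such as a JAX symbolic dimension, counts as symbolic) and that the runtime checker (e.g. typeguard) checks a plain class annotation via isinstance.

```python
class _SkipIfSymbolicMeta(type):
    """Metaclass whose __instancecheck__ skips values with a symbolic shape."""

    def __instancecheck__(cls, obj):
        shape = getattr(obj, "shape", None)
        # Assumption: a concrete shape contains only Python ints; anything
        # else (e.g. a JAX symbolic dimension) is treated as symbolic.
        if shape is not None and not all(isinstance(d, int) for d in shape):
            return True  # symbolic shape: accept without checking
        # Concrete shape (or no shape at all): delegate to the wrapped hint.
        return isinstance(obj, cls.hint)


class SkipIfSymbolic:
    """Hypothetical wrapper: SkipIfSymbolic[hint] checks `hint` only for concrete shapes."""

    def __class_getitem__(cls, hint):
        # Build a fresh class whose metaclass performs the conditional check.
        return _SkipIfSymbolicMeta(f"SkipIfSymbolic[{hint!r}]", (), {"hint": hint})
```

In the reproducer above this would be used as x: SkipIfSymbolic[jt.Float[jt.Array, "*#B"]] (and likewise for the return annotation). One trade-off of this design: a skipped value never binds B, so shape consistency is only enforced between concrete arrays. Whether it interacts cleanly with the @jaxtyped memo is an assumption worth checking.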

patrick-kidger, Aug 22 '25 10:08

@patrick-kidger: thanks! Those are indeed all reasonable options for working around this.

sbodenstein, Sep 01 '25 15:09