Problem with symbolic shapes
This
```python
import jax
from jax import export
from jax import numpy as jnp
import jaxtyping as jt
import typeguard


@jt.jaxtyped(typechecker=typeguard.typechecked)
def f(
    x: jt.Float[jt.Array, "*#B"],
) -> jt.Float[jt.Array, "*#B"]:
    return x * jnp.sum(x) ** 2


dtype = jnp.float32
x_shape = export.symbolic_shape("b")
export.export(jax.jit(f))(
    jax.ShapeDtypeStruct(x_shape, dtype)
)
```
fails with
```
jaxtyping.TypeCheckError: Type-check error whilst checking the return value of __main__.f.
Actual value: f32[b](jax)
Expected type: Float[Array, '*#B'].
----------------------
Called with parameters: {'x': f32[b](jax)}
Parameter annotations: (x: Float[Array, '*#B']) -> Any.
The current values for each jaxtyping axis annotation are as follows.
B=(b,)
```
The problem seems to be the `*#` combination: it works if either the `*` or the `#` is removed.
Thanks for the report! (Though I'll note that your google3 usage is not reproducible in the outside world ;) ) Trimming that out, the error is ultimately raised from this line:
https://github.com/patrick-kidger/jaxtyping/blob/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/jaxtyping/_array_types.py#L298
because naturally jaxtyping/numpy does not know how to broadcast JAX's symbolic shapes against each other.
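To see why broadcasting in particular is the sticking point, here is a minimal sketch. It assumes the `*#` check reduces to something like NumPy's `np.broadcast_shapes`; the code at the linked line is not reproduced here. `SymbolicDim` is a hypothetical stand-in for a JAX symbolic dimension: the only property it shares with the real thing is that it is not an integer, and a real JAX dimension may fail with a different exception type.

```python
import numpy as np


class SymbolicDim:
    """Hypothetical stand-in for a JAX symbolic dimension such as `b`.

    It is deliberately not an integer, which is the property that matters here.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name


def broadcast_or_reason(*shapes):
    """Return the broadcast shape, or a short reason why NumPy refused."""
    try:
        return np.broadcast_shapes(*shapes)
    except ValueError:
        # Integer shapes that genuinely do not broadcast, e.g. (2,) vs (3,).
        return "incompatible"
    except TypeError:
        # NumPy could not interpret some shape entry as an integer.
        return "not integer dims"


b = SymbolicDim("b")
print(broadcast_or_reason((3, 1), (1, 4)))  # (3, 4)
print(broadcast_or_reason((2,), (3,)))      # incompatible
print(broadcast_or_reason((b,), (b,)))      # not integer dims
```

Even two identical symbolic shapes are rejected, because NumPy needs concrete integers before it can apply its broadcasting rules at all.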
These kinds of symbolic shapes have been floating around in JAX for a while, but since you're using them, I'm guessing they've graduated into public API surfaces. So far, at least, we've only supported shapes that are tuples of integers.
Unfortunately I don't think this is a thing we can reasonably support in jaxtyping.
- This is the kind of thing that tends to be pretty unstable in JAX, and I don't think I want to try and stay on top of whatever future changes happen here.
- The name 'jaxtyping' is now historical, and we aim to support all array/tensor libraries on roughly equal footing. Baking in JAXisms may complicate our ability to support other libraries as well.
Thank you for the explanation!
How about supporting a mode that ignores shapes when they are symbolic? That could get users most of the way.
So, a few possible options:
- It's possible to disable jaxtyping (`jaxtyping.config.update("jaxtyping_disable", True)`). You could do this prior to the symbolic section of your code.
- It's very common to only use jaxtyping during tests. Or, put another way, another option is to simply not use jaxtyping alongside symbolic shapes.
- If you did want to avoid specifically the case in which just one array has a symbolic shape, then you could wrap jaxtyping in some kind of `Foo[Float[Array, "x y z"]]` such that `type(Foo).__instancecheck__` delegates to the wrapped hint if and only if the input does not have a symbolic shape.
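The third option can be sketched roughly as follows. This is a minimal, hypothetical implementation, not part of jaxtyping: `SkipIfSymbolic` plays the role of `Foo`, and it assumes a dimension is symbolic exactly when it is not a Python `int` (concrete shapes being tuples of integers). To keep the sketch runnable without JAX, `SymbolicDim`, `FakeArray` and `Rank3` are hypothetical stand-ins for a JAX symbolic dimension, a JAX array, and a hint such as `Float[Array, "x y z"]`.

```python
from typing import Any


class SymbolicDim:
    """Hypothetical stand-in for a JAX symbolic dimension (e.g. `b`)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name


def has_symbolic_shape(obj: Any) -> bool:
    """Assumption: any non-int entry in `obj.shape` counts as symbolic."""
    shape = getattr(obj, "shape", None)
    if shape is None:
        return False
    return not all(isinstance(d, int) for d in shape)


class _SkipIfSymbolicMeta(type):
    # `isinstance(x, SkipIfSymbolic[hint])` is routed to this method.
    def __instancecheck__(cls, obj: Any) -> bool:
        if has_symbolic_shape(obj):
            return True  # accept without checking the wrapped hint
        return isinstance(obj, cls.hint)  # delegate to the wrapped hint


class SkipIfSymbolic:
    """Hypothetical wrapper, used as SkipIfSymbolic[Float[Array, "x y z"]]."""

    def __class_getitem__(cls, hint: Any) -> type:
        # Build a fresh class whose metaclass implements the delegation.
        return _SkipIfSymbolicMeta(f"SkipIfSymbolic[{hint!r}]", (), {"hint": hint})


# --- Demo with stand-ins (no JAX required) ---
class FakeArray:
    """Hypothetical array-like object exposing only a `.shape`."""

    def __init__(self, shape: tuple) -> None:
        self.shape = shape


class _Rank3Meta(type):
    def __instancecheck__(cls, obj: Any) -> bool:
        return len(getattr(obj, "shape", ())) == 3


class Rank3(metaclass=_Rank3Meta):
    """Stand-in for a real hint such as Float[Array, "x y z"]."""


Checked = SkipIfSymbolic[Rank3]
print(isinstance(FakeArray((2, 3, 4)), Checked))             # True: concrete, passes
print(isinstance(FakeArray((2, 3)), Checked))                # False: concrete, fails
print(isinstance(FakeArray((SymbolicDim("b"),)), Checked))   # True: symbolic, skipped
```

In the original example one would then annotate `x: SkipIfSymbolic[jt.Float[jt.Array, "*#B"]]` (and likewise the return type). This assumes the type checker falls back to `isinstance` for ordinary classes, as typeguard does for jaxtyping's own hints. One trade-off: a skipped argument binds no axis sizes, so other annotations sharing `B` are not cross-checked against it.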
@patrick-kidger: thanks! Those are indeed all reasonable options for working around this.