jaxtyping

Feature Request: Allow for unbound composite structures as return types

Open pdcook opened this issue 1 month ago • 2 comments

Suppose one has a commonly used type with a given structure that is type aliased, and another commonly used type that should have this structure as a prefix:


ModelType: TypeAlias = PyTree[Any, "M"]
ModelParams: TypeAlias = PyTree[Any, "M ..."]

This works well, except in the (in my opinion) reasonable case of an argument-less function that involves this second type as a return type:

def get_params() -> ModelParams:
    ...

This will always throw an AnnotationError, since the structure name "M" has not been seen in the scope of this function.

The obvious workaround is to manually write either PyTree[Any, "M ..."] or PyTree[Any] every single time, depending on whether "M" can be bound or not. But this is exceedingly opaque to any user and is tedious for more complicated, real-world types.

It's possible there would be a downside to allowing for unbound structures in return types, but we already have this functionality for Array shapes:

def get_array() -> Num[Array, "d ..."]:
    ...

works perfectly fine, despite "d" not having been seen in the scope of the function.

For reference, in case it makes a difference, I'm using beartype as my typechecker.

pdcook, Nov 23 '25 22:11

Hey there! So the difference is that:

  • in the case of the array, we can uniquely determine what d is – it's always just a single dimension, however...
  • ...pytree structure annotations are all able to cover 'multiple levels'/'multiple dimensions' of the pytree structure.
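
The contrast between these two bullets can be sketched in plain Python. This is a simplified model under stated assumptions, not jaxtyping's implementation: `bind_leading_dim` and `candidate_prefixes` are hypothetical helpers, pytrees are modeled as nested tuples only, and `"*"` stands for "any subtree may hang here".

```python
from itertools import product

LEAF = "*"  # placeholder meaning "any subtree may hang here"


def bind_leading_dim(shape):
    """Model of "d ...": `d` is always exactly the first axis."""
    return shape[0]


def candidate_prefixes(tree):
    """Model of "M ...": list every structure that is a prefix of `tree`.

    Assumption: only tuples count as pytree nodes; anything else is a leaf.
    """
    candidates = [LEAF]  # cutting the tree off at this node is always allowed
    if isinstance(tree, tuple):
        # Keep this node, then choose a prefix for each child independently.
        child_options = [candidate_prefixes(child) for child in tree]
        candidates.extend(product(*child_options))
    return candidates


# An array shape offers exactly one binding for `d`.
d = bind_leading_dim((3, 4, 5))

# A small pytree already offers several possible bindings for `M`.
params = ((1.0, 2.0), 3.0)
m_candidates = candidate_prefixes(params)
```

In this model, `d` has a single possible value, while `m_candidates` holds three different structures that could each play the role of "M", and nothing in the value itself says which one was meant.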

Correspondingly, when we see "M ..." we can't know what M refers to unless it has already appeared before on its own. And if we allowed that, then a return type of the form tuple[PyTree[Any, "M ..."], PyTree[Any, "M ..."]] would require us to analyze their shared pytree structure to find their mutual prefix... and the whole thing starts getting quite hairy. In the general case this probably looks something like parsing a regex.
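
To make the mutual-prefix problem concrete, here is a minimal sketch in plain Python. It is an illustration only, not how jaxtyping works internally: `common_prefixes` is a hypothetical helper, pytrees are modeled as nested tuples, and `"*"` marks a point below which anything is allowed.

```python
from itertools import product

LEAF = "*"  # placeholder meaning "any subtree may hang here"


def common_prefixes(a, b):
    """List every structure that is a prefix of both `a` and `b`.

    Assumption: only tuples count as pytree nodes; anything else is a leaf.
    """
    shared = [LEAF]  # a bare leaf is a prefix of every tree
    both_nodes = isinstance(a, tuple) and isinstance(b, tuple)
    if both_nodes and len(a) == len(b):
        # The node itself matches, so recurse into each pair of children.
        child_options = [common_prefixes(x, y) for x, y in zip(a, b)]
        shared.extend(product(*child_options))
    return shared


# Two hypothetical return values, both annotated with "M ...".
first = ((1, 2), (3, 4))
second = ((5, (6, 7)), 8)
shared = common_prefixes(first, second)
```

Even after restricting to prefixes shared by both values, several candidates for "M" remain in this toy model, so a checker would additionally need some rule for choosing among them.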

For this reason I'm afraid this is likely to remain unsupported.

patrick-kidger, Nov 24 '25 10:11

That makes perfect sense. Thanks for the quick reply!

pdcook, Nov 24 '25 22:11