Support for type checking dataclass properties

Open sbodenstein opened this issue 4 months ago • 2 comments

Suppose you have a dataclass:

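from dataclasses import dataclass

# Assumption: beartype is the runtime checker; the original post does not say which one is used.
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped

@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
    x: Float[Array, "n"]

    @property
    def plus1(self) -> Float[Array, "n"]:
        return self.x + 1.0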

AFAIK there is no way to enforce that the "n" in the property's return annotation is the same "n" bound in the class definition. Is there a way to enforce this, and is it a feature that would make sense? Note also that symbolic shapes cannot be checked here either: e.g. Float[Array, "n+1"] would fail because n is not bound by any of the function's inputs.

sbodenstein, Sep 01 '25 16:09

Yup, there is: Float[Array, "{self.x.shape[0]}"] should work. That is, an f-string-without-the-f. Such 'f-strings' are evaluated by jaxtyping at runtime, using the arguments passed to the function.

If that syntax seems a bit funky to you (why not just use n directly?), the rationale is that jaxtyping axis names are considered local to the lexical context of the function in which they're defined. For dataclass attributes, that context is the autogenerated __init__ method. Anything beyond that starts being too much 'spooky action at a distance' IMO.

patrick-kidger, Sep 05 '25 23:09

Ah great!

sbodenstein, Sep 08 '25 18:09