Support for type checking dataclass properties
Suppose you have a dataclass:
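```python
from dataclasses import dataclass

from beartype import beartype as typechecker  # any typechecker jaxtyping supports
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
    x: Float[Array, "n"]

    @property
    def plus1(self) -> Float[Array, "n"]:
        return self.x + 1.0
```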
As far as I know, there is no way to enforce that the "n" in the property's return annotation is the same "n" that is bound by the class definition. Is there a way to enforce this, and would it be a sensible feature? Note also that symbolic shapes cannot be checked here either: e.g. Float[Array, "n+1"] fails, because n is not bound by any of the property's inputs.
Yup, there is: Float[Array, "{self.x.shape[0]}"] should work. That is, an f-string-without-the-f. Such 'f-strings' are evaluated by jaxtyping at runtime, using the arguments provided to the function.
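To make the "f-string-without-the-f" idea concrete, here is a toy, pure-Python sketch of the mechanism. It is not jaxtyping's implementation or API: `Vec`, `returns_len` and `drop_last` are hypothetical names, and a one-dimensional `Vec` stands in for a JAX array so that the example runs with only the standard library. The annotation string is filled in with `str.format` using the call's bound arguments, which for a property getter is just `self`.

```python
import functools
import inspect
from dataclasses import dataclass


@dataclass
class Vec:
    """Hypothetical stand-in for a 1-D array: a tuple with a `.shape`."""
    data: tuple[float, ...]

    @property
    def shape(self) -> tuple[int, ...]:
        return (len(self.data),)


def returns_len(spec: str):
    """Hypothetical toy decorator, NOT jaxtyping's API or implementation.

    `spec` is an f-string-without-the-f such as "{self.x.shape[0]}". It is
    filled in via str.format with the wrapped call's bound arguments, and
    the result is compared against the shape of the returned Vec.
    """
    def decorator(fn):
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            # Only the function's own arguments are available to the spec,
            # mirroring the "local to the function" rule from the thread.
            expected = int(spec.format(**bound.arguments))
            out = fn(*args, **kwargs)
            if out.shape != (expected,):
                raise TypeError(
                    f"{fn.__name__}: expected shape ({expected},), got {out.shape}"
                )
            return out

        return wrapper

    return decorator


@dataclass
class MyDataclass:
    x: Vec

    @property
    @returns_len("{self.x.shape[0]}")
    def plus1(self) -> Vec:
        return Vec(tuple(v + 1.0 for v in self.x.data))

    @property
    @returns_len("{self.x.shape[0]}")
    def drop_last(self) -> Vec:
        # Deliberately returns the wrong length to trigger the check.
        return Vec(self.x.data[:-1])


m = MyDataclass(Vec((1.0, 2.0, 3.0)))
print(m.plus1)  # Vec(data=(2.0, 3.0, 4.0))
try:
    m.drop_last
except TypeError as err:
    print(err)  # drop_last: expected shape (3,), got (2,)
```

Because `str.format` field names only support attribute access and indexing, this toy cannot express arithmetic such as `+1`; it is only meant to show where the axis size comes from (the function's own arguments), not how jaxtyping parses or evaluates shape strings.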
If that syntax seems a bit funky (why not just use n directly?), the rationale is that jaxtyping axis names are considered local to the lexical context of the function in which they're defined. (For dataclass attributes, that function is the autogenerated __init__ method.) Anything beyond that starts to become too much 'spooky action at a distance', IMO.
Ah great!