Splicing / variadic symbolic expressions
Would it be possible to make the following code snippet work?
```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor


class A:
    def __init__(self, shape: tuple[int, ...]):
        self.shape = shape

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "... {self.shape}"]) -> Float[Tensor, "..."]:
        return x.flatten(start_dim=-len(self.shape)).sum(dim=-1)


a = A((3, 10, 5))
x = torch.randn((7, 3, 4, 5))
print(a.forward(x))
```
At the moment this does not work, as far as I can tell, because `{self.shape}` is matched against only a single dimension of `x`. Is there a way to evaluate the expression and splice the tuple value into the type before it is matched against the dimensions? Maybe with something like a `*{self.shape}` syntax?
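To make the requested semantics concrete, here is a minimal sketch of what such a splicing step might do: expand a token into one dimension per tuple entry before any matching happens. The helper `splice_dims` and the `*{name}` token syntax are hypothetical and not part of jaxtyping. The sketch only looks up plain names in a dict (not attribute accesses like `self.shape`) and only rewrites the dimension string; it does no shape checking itself.

```python
import re

# Hypothetical splice token: "*{name}" standing alone as one dimension entry.
# This is NOT jaxtyping syntax; it only illustrates the idea from the question.
_SPLICE = re.compile(r"^\*\{(\w+)\}$")


def splice_dims(spec: str, env: dict[str, tuple[int, ...]]) -> str:
    """Expand every "*{name}" token in a dimension string into the
    space-separated entries of env[name]; other tokens are kept as-is."""
    out: list[str] = []
    for token in spec.split():
        m = _SPLICE.match(token)
        if m:
            # Splice each tuple entry in as its own dimension.
            out.extend(str(d) for d in env[m.group(1)])
        else:
            out.append(token)
    return " ".join(out)


print(splice_dims("... *{shape}", {"shape": (3, 10, 5)}))
```

With the shape from the question, this prints `... 3 10 5`, i.e. a string in which each entry of the tuple would be matched against its own dimension.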
Yup, this is a known issue. I don't have a nice way to fix this right now -- this is quite a complicated corner of jaxtyping! -- but I'd be happy to take a PR if someone feels like taking this on.
If need be, you could do something like `str(self.shape).replace(",", " ")[1:-1]`, but that's obviously pretty messy.
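For reference, here is what that string manipulation produces, next to a `" ".join(...)` equivalent that avoids the doubled spaces (and the trailing space a one-element tuple such as `(7,)` would leave). The variable names are illustrative, and this sketch does not check whether jaxtyping's parser tolerates the extra whitespace. One caveat, assuming ordinary (non-postponed) annotations: they are evaluated when the `def` statement runs, where `self` does not exist, so either string can only be interpolated into an f-string annotation when the shape is already known at that point, e.g. a module-level constant.

```python
shape = (3, 10, 5)

# The str(...) workaround: strip the parentheses from the tuple's repr
# and turn the commas into spaces.
messy = str(shape).replace(",", " ")[1:-1]
print(repr(messy))  # '3  10  5' (note the doubled spaces)

# An equivalent that builds the string directly from the entries.
clean = " ".join(map(str, shape))
print(repr(clean))  # '3 10 5'

# Either string could go into an f-string annotation, but only where the
# shape is known when the annotation is evaluated.
annotation = f"... {clean}"
print(annotation)  # ... 3 10 5
```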