Splicing / variadic symbolic expressions

Open · martenlienen opened this issue 1 year ago · 1 comment

Would it be possible to make the following code snippet work?

```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

class A:
    def __init__(self, shape: tuple[int, ...]):
        self.shape = shape

    # Desired: "{self.shape}" should match the last len(self.shape) dims of x,
    # with "..." covering any leading batch dims.
    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "... {self.shape}"]) -> Float[Tensor, "..."]:
        # Flatten the trailing len(self.shape) dims into one, then sum over it.
        return x.flatten(start_dim=-len(self.shape)).sum(dim=-1)

a = A((3, 10, 5))
x = torch.randn((7, 3, 4, 5))
print(a.forward(x))
```

As far as I can tell this does not work at the moment, because `{self.shape}` is matched against only a single dimension of `x`. Is there a way to evaluate the expression and splice the tuple's values into the type before it is matched against the dimensions? Maybe with something like a `*{self.shape}` syntax?
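A rough, plain-Python sketch of the check that a spliced `*{self.shape}` dimension might perform: each element of the tuple is compared against its own trailing dimension of `x`, and whatever is left over is what `...` binds to. `match_spliced_shape` is a hypothetical helper for illustration only; it is not a jaxtyping API and does not involve type annotations at all.

```python
def match_spliced_shape(x_shape: tuple[int, ...], spliced: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper (not part of jaxtyping): check that x_shape ends
    with the dims in `spliced` and return the leading dims that "..." binds to."""
    n = len(spliced)
    if len(x_shape) < n:
        raise TypeError(f"expected at least {n} dims, got shape {x_shape}")
    # Compare each element of the tuple against its own dimension,
    # rather than the whole tuple against a single dimension.
    split = len(x_shape) - n
    trailing = x_shape[split:]
    if trailing != tuple(spliced):
        raise TypeError(f"trailing dims {trailing} do not match {tuple(spliced)}")
    # Everything before the spliced dims is the "..." part.
    return x_shape[:split]


print(match_spliced_shape((7, 3, 10, 5), (3, 10, 5)))  # (7,)

try:
    match_spliced_shape((7, 3, 4, 5), (3, 10, 5))
except TypeError as e:
    print(e)  # trailing dims (3, 4, 5) do not match (3, 10, 5)
```

With the shapes used in the snippet, `(7, 3, 4, 5)` against `(3, 10, 5)`, this kind of check would reject `x`, because the middle trailing dimension is 4 rather than 10.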

martenlienen · Nov 12 '24 08:11

Yup, this is a known issue. I don't have a nice way to fix this right now -- this is quite a complicated corner of jaxtyping! -- but I'd be happy to take a PR if someone feels like taking this on.

If need be, you could do something like `str(self.shape).replace(",", " ")[1:-1]`, but that's obviously pretty messy.
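The string that this workaround produces can be checked in plain Python. The sketch below only shows the string manipulation, not how the result gets into an annotation; `shape_to_dims` is a hypothetical helper added for comparison and is not part of jaxtyping.

```python
def shape_to_dims(shape: tuple[int, ...]) -> str:
    """Hypothetical helper (not part of jaxtyping): render a shape tuple as a
    space-separated dimension string, e.g. (3, 10, 5) -> "3 10 5"."""
    return " ".join(str(d) for d in shape)


shape = (3, 10, 5)

# The workaround: str(shape) is "(3, 10, 5)"; replacing "," with " " gives
# "(3  10  5)", and [1:-1] drops the parentheses.
messy = str(shape).replace(",", " ")[1:-1]
print(repr(messy))  # '3  10  5'  (note the doubled spaces)

# A more direct way to build the same dimensions.
clean = shape_to_dims(shape)
print(repr(clean))  # '3 10 5'

# Splitting on whitespace yields the same tokens for both strings.
print(messy.split() == clean.split())  # True

# Edge case: a one-element tuple renders as "(3,)", leaving a trailing space.
print(repr(str((3,)).replace(",", " ")[1:-1]))  # '3 '
```

The doubled and trailing spaces are part of why the replace-and-slice version is messy; joining the elements directly avoids them.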

patrick-kidger · Nov 12 '24 09:11