jaxtyping

Jaxtyping a class with mutable shapes

Open: SConsul opened this issue 10 months ago • 1 comment

I want to typecheck a class (a TypedDict in the example below) with the contract that, at any instant, all shape variables (e.g. B, T, H, W) are the same for every tensor in the class, while transforms such as crop_sample can still modify it. Is there a way to rebind H and W after such a shape-altering operation?

from typing import TypedDict

import torch
from jaxtyping import Float32, jaxtyped
from typeguard import typechecked



@jaxtyped(typechecker=typechecked)
class MyDict(TypedDict, total=False):
    foo1: Float32[torch.Tensor, "B T 3 H W"]
    foo2: Float32[torch.Tensor, "B T 3 H W"]
    baz: Float32[torch.Tensor, "B 1 4 4"]


@jaxtyped(typechecker=typechecked)
def crop_sample(sample: MyDict) -> MyDict:
    # Ensure we modify all tensors with the same crop
    h_start, w_start = 50, 50
    foo1 = sample["foo1"][:, :, :, h_start:, w_start:]
    foo2 = sample["foo2"][:, :, :, h_start:, w_start:]
    # Note: this mutates the caller's dict in place
    sample["foo1"] = foo1
    sample["foo2"] = foo2
    return sample


if __name__ == "__main__":
    my_dict = MyDict(foo1=torch.randn(1, 1, 3, 100, 100), foo2=torch.randn(1, 1, 3, 100, 100))
    print(my_dict["foo1"].shape)
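    # Expected to fail the return-type check: H and W were bound to 100 from the
    # argument, but the returned tensors are 50 x 50.
    my_dict = crop_sample(my_dict)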
    print(my_dict["foo1"].shape)

SConsul, Feb 28 '25 04:02
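A toy model of the semantics may help show why the call above fails. The sketch below is plain Python, not jaxtyping's implementation; check_shape and crop_shape are hypothetical helpers. Each named dimension is bound on first use within one checked call, and every later use in that same call, including the return value, must match.

```python
def check_shape(shape, spec, memo):
    """Toy dimension binder (hypothetical, not jaxtyping's code).

    Named dims are bound in `memo` on first use; later uses must match.
    Numeric dims such as "3" are fixed sizes.
    """
    dims = spec.split()
    if len(dims) != len(shape):
        raise ValueError(f"expected {len(dims)} dims, got shape {shape}")
    for name, size in zip(dims, shape):
        if name.isdigit():
            if int(name) != size:
                raise ValueError(f"dim {name!r}: expected {name}, got {size}")
        elif memo.setdefault(name, size) != size:
            raise ValueError(f"dim {name!r} already bound to {memo[name]}, got {size}")


def crop_shape(shape, h_start=50, w_start=50):
    """Shape of x[:, :, :, h_start:, w_start:] for a 5-D tensor."""
    b, t, c, h, w = shape
    return (b, t, c, max(h - h_start, 0), max(w - w_start, 0))


SPEC = "B T 3 H W"
memo = {}  # one scope, shared by the argument and return checks of one call

before = (1, 1, 3, 100, 100)
check_shape(before, SPEC, memo)  # foo1 argument: binds B=1, T=1, H=100, W=100
check_shape(before, SPEC, memo)  # foo2 argument: consistent with the bindings

after = crop_shape(before)  # (1, 1, 3, 50, 50)
error = None
try:
    check_shape(after, SPEC, memo)  # return value, checked in the same scope
except ValueError as exc:
    error = str(exc)

print(error)  # dim 'H' already bound to 100, got 50
```

In jaxtyping terms, the argument check plays the role of the first two calls and the `-> MyDict` return annotation the last one, so the cropped return value is compared against the H and W bound from the input.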

I'm afraid there isn't. The intended interpretation of these annotations is a check that the shapes are the same.

patrick-kidger, Feb 28 '25 18:02
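Given that reading, two workarounds seem plausible. Both are sketched below in a toy model (plain Python, not jaxtyping code; check_shape, check_sample, Hc and Wc are hypothetical names). The first gives the cropped dimensions their own symbols, which in jaxtyping would presumably mean a second TypedDict whose fields are annotated "B T 3 Hc Wc" and used as crop_sample's return type. The second checks the cropped sample in a fresh scope, for example by passing it to a separate @jaxtyped-decorated function, since jaxtyping binds dimension sizes per decorated call.

```python
def check_shape(shape, spec, memo):
    """Toy dimension binder (hypothetical): bind names on first use, then match."""
    dims = spec.split()
    if len(dims) != len(shape):
        raise ValueError(f"expected {len(dims)} dims, got shape {shape}")
    for name, size in zip(dims, shape):
        expected = int(name) if name.isdigit() else memo.setdefault(name, size)
        if expected != size:
            raise ValueError(f"dim {name!r}: expected {expected}, got {size}")


def check_sample(shapes, spec, memo=None):
    """Check every tensor shape in one scope; a new scope is used if memo is None."""
    memo = {} if memo is None else memo
    for shape in shapes:
        check_shape(shape, spec, memo)
    return memo


inputs = [(1, 1, 3, 100, 100), (1, 1, 3, 100, 100)]  # foo1, foo2 before cropping
outputs = [(1, 1, 3, 50, 50), (1, 1, 3, 50, 50)]  # foo1, foo2 after cropping

# Workaround 1: same scope, but the output spec uses fresh names Hc, Wc,
# so B and T must still match the input while the spatial dims may change.
scope = check_sample(inputs, "B T 3 H W")
check_sample(outputs, "B T 3 Hc Wc", scope)

# Workaround 2: a fresh scope for the cropped sample, so H and W bind anew
# while foo1 and foo2 must still agree with each other.
fresh = check_sample(outputs, "B T 3 H W")

print(scope)  # {'B': 1, 'T': 1, 'H': 100, 'W': 100, 'Hc': 50, 'Wc': 50}
print(fresh)  # {'B': 1, 'T': 1, 'H': 50, 'W': 50}
```

The trade-off: the first keeps the link between input and output batch dimensions in a single check, while the second only verifies that the cropped sample is internally consistent.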