jaxtyping
Jaxtyping a class with mutable shapes
I want to typecheck a class (a `TypedDict` in the example below) with the contract that, at any instant, every shape variable (e.g. `B`, `T`, `H`, `W`) has the same value across all tensors in the class, while transforms such as `crop_sample` are still allowed to change those shapes. Is there a way to rebind `H` and `W` after such a shape-altering operation?
```python
from typing import TypedDict

import torch
from jaxtyping import Float32, jaxtyped
from typeguard import typechecked


@jaxtyped(typechecker=typechecked)
class MyDict(TypedDict, total=False):
    foo1: Float32[torch.Tensor, "B T 3 H W"]
    foo2: Float32[torch.Tensor, "B T 3 H W"]
    baz: Float32[torch.Tensor, "B 1 4 4"]


@jaxtyped(typechecker=typechecked)
def crop_sample(sample: MyDict) -> MyDict:
    # Ensure we modify all tensors with the same crop
    h_start, w_start = 50, 50
    foo1 = sample["foo1"][:, :, :, h_start:, w_start:]
    foo2 = sample["foo2"][:, :, :, h_start:, w_start:]
    sample["foo1"] = foo1
    sample["foo2"] = foo2
    return sample


if __name__ == "__main__":
    my_dict = MyDict(foo1=torch.randn(1, 1, 3, 100, 100), foo2=torch.randn(1, 1, 3, 100, 100))
    print(my_dict["foo1"].shape)
    # H and W are bound to 100 when the argument is checked, so the
    # 50x50 tensors in the return value fail the return-type check.
    my_dict = crop_sample(my_dict)
    print(my_dict["foo1"].shape)
```
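I'm afraid there isn't. The intended interpretation is to check that the shapes are the same: within a single `@jaxtyped` call, a dimension name such as `H` is bound the first time it is seen (here, when the argument is checked), and every later occurrence, including those in the return annotation, must match that value.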