
Can jaxtyping be used with `tensordict`?

Open RuiWang1998 opened this issue 1 year ago • 3 comments

I find jaxtyping a blast to use, as is tensordict (https://github.com/pytorch/tensordict). If I could use the two together, that would be even more amazing!

Thanks for the great work!

RuiWang1998 • Nov 05 '24 04:11

Hmm, it's not clear to me how we'd go about annotating a TensorDict, for the same reason that it's not obvious how to annotate the shapes and dtypes of the values of a regular dict.

But... have you tried the `@tensorclass` decorator that tensordict also provides? Since it explicitly annotates each contained element, I think it should be exactly what you need.

patrick-kidger • Nov 05 '24 07:11

I can confirm that `tensorclass` works:

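```python
from jaxtyping import Float, Shaped
from tensordict import tensorclass
from torch import Tensor

@tensorclass
class CameraIntrinsics:
    fx: Float[Tensor, " *num_cameras"]
    fy: Float[Tensor, " *num_cameras"]
    cx: Float[Tensor, " *num_cameras"]
    cy: Float[Tensor, " *num_cameras"]
# ...
# Shaped accepts any dtype; " *num_cameras" describes the tensorclass's batch size.
def some_other_fun(intrinsics: Shaped[CameraIntrinsics, " *num_cameras"]):
    ...
```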

Mxbonn • Nov 06 '24 09:11

Great! Thanks for the advice!

RuiWang1998 • Nov 06 '24 11:11