jaxtyping
Are PyTorch named tensors supported, like in torchtyping?
I'm afraid there is no special support for these. I don't really recommend using them: PyTorch itself seems to have mostly given up on adding further support for them, so they only work with a small number of PyTorch operations anyway.
That means you cannot access the dimension names, right? For example, I would like to apply something across all dimensions named L. Thanks!
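from dataclasses import dataclass
from jaxtyping import Float32
from torch import Tensor

@dataclass
class Label:
    coordinates: Float32[Tensor, "L 3"]  # one 3-D coordinate per element
    distances: Float32[Tensor, "L L"]    # pairwise distances between elements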
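One possible workaround, offered only as a sketch and not as anything jaxtyping provides: keep the shape strings in a plain mapping of your own and look up the axis positions by name. The helper `axes_named` and the `LABEL_SHAPES` dict below are hypothetical names made up for illustration, and the parsing assumes simple space-separated axis names (no `*batch`, `...`, or other modifiers).

```python
# Hypothetical helper, not part of jaxtyping: find axis positions by name
# in a simple space-separated shape string such as "L 3" or "L L".
def axes_named(shape_spec: str, name: str) -> tuple[int, ...]:
    """Return the index of every axis called `name` in `shape_spec`."""
    return tuple(i for i, token in enumerate(shape_spec.split()) if token == name)


# Assumption: the shape strings are duplicated here by hand so they can be
# inspected at runtime without relying on jaxtyping internals.
LABEL_SHAPES = {"coordinates": "L 3", "distances": "L L"}

l_axes = {field: axes_named(spec, "L") for field, spec in LABEL_SHAPES.items()}
print(l_axes)  # {'coordinates': (0,), 'distances': (0, 1)}
```

The resulting tuples could then be passed to an ordinary reduction, for example `label.distances.sum(dim=l_axes["distances"])`, which would give "apply something across all dimensions named L" without named tensors.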