long-short-transformer
Understanding the Input
Can you help me understand what my input shape (x) should be? When I print the shape of x = torch.randint(0, 20000, (1, 1024)), it is (1, 1024). My input has 1024 tokens, each with 2048 features, but passing an input of shape (B, 1024, 2048) gives this error:
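The example input x = torch.randint(0, 20000, (1, 1024)) holds integer token ids, one per position, so its shape is (batch, seq_len). The 4D shape in the error suggests the model looks each input element up in a token embedding of width 512 before the attention layers. The sketch below assumes that, and also assumes num_tokens=20000 and dim=512 (taken from the example and the error). It uses plain torch.nn.Embedding to show why a 3D input picks up an extra axis. Smaller sizes are used for the 3D case to keep memory low.

```python
import torch
import torch.nn as nn

# Assumed sizes: vocabulary of 20000 token ids, embedding width 512 (the 512 in the error).
num_tokens, dim = 20000, 512
embed = nn.Embedding(num_tokens, dim)

# Token ids: one integer per position, shape (batch, seq_len).
ids = torch.randint(0, num_tokens, (1, 1024))
out_2d = embed(ids)
print(out_2d.shape)  # torch.Size([1, 1024, 512])

# A 3D integer tensor gets one 512-dim vector per element, which adds a fourth axis.
ids_3d = torch.randint(0, num_tokens, (1, 16, 8))
out_3d = embed(ids_3d)
print(out_3d.shape)  # torch.Size([1, 16, 8, 512])
```

A (1, 1024, 2048) input therefore becomes (1, 1024, 2048, 512), which matches the tensor shape reported in the error.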
einops.EinopsError: Error while processing rearrange-reduction pattern "b n (h d) -> (b h) n d". Input tensor shape: torch.Size([1, 1024, 2048, 512]). Additional info: {'h': 8}.