rotary-embedding-torch
Implementation of Rotary Embeddings, from the RoFormer paper, in PyTorch
Hi, thank you for sharing this code with us. However, I was confused by the axial rotary embeddings in the rotary_embedding_torch.py file. " elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq...
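For context, a minimal sketch of how the two frequency modes differ (the `'lang'` formula follows the RoFormer paper; the `'pixel'` line matches the snippet quoted above):

```python
import torch
from math import pi

dim, max_freq = 32, 10

# 'lang': inverse-power frequencies over integer token positions, as in RoFormer
lang_freqs = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))

# 'pixel': linearly spaced frequencies scaled by pi, intended for coordinates
# normalized to [-1, 1] rather than integer token positions
pixel_freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
```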
Two issues in one, as they seemingly come from the same function. Right now, if I `torch.jit.trace` a module that uses `rotate_queries_or_keys()`, I hit the following `TracerWarning`: ``` TracerWarning: Converting...
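A minimal sketch of a tracing setup that should surface the warning, assuming a plain wrapper module (the warning presumably comes from tensor shapes being converted to Python values inside the frequency-cache logic):

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

class QueryRotator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rotary = RotaryEmbedding(dim = 32)

    def forward(self, q):
        return self.rotary.rotate_queries_or_keys(q)

q = torch.randn(1, 8, 128, 64)
traced = torch.jit.trace(QueryRotator(), q)  # emits the TracerWarning described above
```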
Is it possible to easily use axial rotary embeddings with your x-transformers without having to dissect the Attention module? At first glance it seems that there is no simple way...
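For reference, a hedged sketch of applying axial rotary embeddings to queries outside any attention module, assuming the `get_axial_freqs` and `apply_rotary_emb` helpers this library exposes for pixel data:

```python
import torch
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

pos_emb = RotaryEmbedding(dim = 16, freqs_for = 'pixel', max_freq = 256)

# one frequency grid for a 2d feature map; queries and keys get rotated
# here rather than inside whatever attention module consumes them
freqs = pos_emb.get_axial_freqs(64, 64)

q = torch.randn(1, 8, 64, 64, 64)  # (batch, heads, height, width, head dim)
q = apply_rotary_emb(freqs, q)
```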
Hi, thank you very much for this handy rotary embedding library. I encountered this runtime error when the rotary embedding was trying to read the cached frequencies at the second `loss.backward()`...
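A hedged sketch of the training-loop shape that can trigger this, assuming the error is the usual "backward through the graph a second time" caused by frequencies computed from learned parameters on step one being cached and reused on step two:

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim = 32, learned_freq = True)

for step in range(2):
    q = torch.randn(1, 8, 64, 64)
    out = rotary.rotate_queries_or_keys(q)
    loss = out.sum()
    loss.backward()  # second call can fail if cached frequencies are still
                     # attached to the first step's autograd graph
```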
Fixing reference to parameter
Hi! I'm interested in using the rotary embeddings with `use_xpos=True` so my transformer can extrapolate to longer sequence lengths. However, I noticed the readme mentions this technique works only with autoregressive transformers. Is there...
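A sketch of the intended usage, following the readme: with `use_xpos = True`, queries and keys must be rotated together so the xpos scaling reduces to a purely relative factor, which is also why the readme restricts it to autoregressive (causal) attention:

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim = 32, use_xpos = True)

q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)

# rotate jointly; the per-position scales applied to q and k are reciprocal
# and only cancel correctly when key positions precede query positions
q, k = rotary.rotate_queries_and_keys(q, k)
```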
In your demo code, the dim of q is 64 while the dim of RotaryEmbedding is 32. I checked the code: q channels with index larger than 32 will not be rotated...
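This is partial rotary application over the feature dimension, not over positions: only the first `dim` channels of each head are rotated and the rest pass through. A small check, assuming the trailing channels are left untouched:

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim = 32)   # rotary dim intentionally half the head dim
q = torch.randn(1, 8, 1024, 64)      # head dim 64, as in the readme demo

out = rotary.rotate_queries_or_keys(q)

# channels 32.. are passed through unchanged; rotating only part of the
# head dim is a common choice (e.g. GPT-NeoX), so dim = 32 here is deliberate
assert torch.allclose(out[..., 32:], q[..., 32:])
```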
Hi, I am debugging an issue with [my model](https://github.com/thorinf/simple-diffusion-lm) not learning longer contexts. It could be countless things, but I wanted to check if there are required tricks, or best...
Hello, thank you for the amazing work! I had a brief question: shouldn't `(n r)` in repeat be `(r n)` [here](https://github.com/lucidrains/rotary-embedding-torch/blob/783d17820ac1e75e918ae2128ab8bbcbe4985362/rotary_embedding_torch/rotary_embedding_torch.py#L277)? Since `(r n) != (n r)`, as `(r n)` would be...
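The two patterns are indeed different; a quick einops check shows why `(n r)` (interleaved duplication) is the one that pairs adjacent feature channels with the same rotation angle:

```python
import torch
from einops import repeat

freqs = torch.tensor([10., 20., 30.])

# '(n r)': n is the outer (slower) index, so each frequency is duplicated
# in place -> adjacent channel pairs share one rotation angle
repeat(freqs, 'n -> (n r)', r = 2)   # tensor([10., 10., 20., 20., 30., 30.])

# '(r n)': tiles the whole sequence instead -> wrong channel pairing
repeat(freqs, 'n -> (r n)', r = 2)   # tensor([10., 20., 30., 10., 20., 30.])
```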
I haven't investigated this yet, but the latest commit makes the MeshGPT tests fail, and users get the error below: ``` File /usr/local/lib/python3.10/dist-packages/local_attention/transformer.py:152, in LocalMHA.forward(self, x, mask, attn_bias, cache, return_cache) 149...