rotary-embedding-torch
LieRE: Generalizing Rotary Position Encodings. Beats RoPE-mixed by a large margin and is much faster (compute-wise)
Hi, @lucidrains !
There was promising research published this month (vs. RoPE-mixed (#25) in March): the so-called LieRE positional encodings generalize the kv-vector rotation to any number of dimensions (1D, 2D, 3D, etc.), and they are much simpler than RoPE in formulation. More than that, they result in much better model accuracy and 25%+ faster training than either axial RoPE or RoPE-mixed. I think their paper is really underappreciated, and this approach will be revolutionary.
LieRE leads to marked improvements in performance (up to 6%), training efficiency (3.5x reduction), data efficiency (30%) compared to the baselines of RoFormer, DeiT III, RoPE-Mixed and Vision-Llama.
The paper is here: https://arxiv.org/abs/2406.10322. The LieRE authors have given only pseudocode so far; however, it looks extremely simple.
It looks easy, but I'm a bit confused about how to implement the block-diagonal skew matrix with minimal learnable components while preserving its structure (a stack of n 1D parameters + tril_indices + a block matrix?). Integrating block-sparse optimizations for fast rotations would also be nice to have.
@kabachuha new rotary embeddings research! thank you for this, will check it out!
My current understanding:
import torch
def flat_to_skew(x, liere_block_size, axes_length, spacial_dims):
    A = torch.zeros(liere_block_size, liere_block_size, axes_length, spacial_dims).to(x.device)
    i, j = torch.tril_indices(liere_block_size, liere_block_size, offset=-1)  # strictly below the diagonal
    for d in range(spacial_dims):
        A[i, j, :, d] = x[:, :, d]
        A[j, i, :, d] = -x[:, :, d]  # skew symmetry
    return A
class AttentionLiereRotator(torch.nn.Module):
    def __init__(self, head_dim, liere_block_size, spacial_dims, axes_length, num_heads):
        super().__init__()
        assert head_dim % liere_block_size == 0 and liere_block_size <= head_dim
        self.liere_block_size = liere_block_size
        self.head_dim = head_dim
        self.spacial_dims = spacial_dims
        self.axes_length = axes_length
        self.num_heads = num_heads

        # trainable parameters (for the skew matrices)
        self.vars = torch.nn.ParameterList([
            torch.nn.Parameter(torch.randn([(liere_block_size * liere_block_size - liere_block_size) // 2, axes_length, spacial_dims]))
            for _ in range(head_dim // liere_block_size)
        ])

        self.spacial_indices = torch.arange(0, axes_length).unsqueeze(1).repeat([1, self.spacial_dims])

    def forward(self, x: torch.Tensor, matrices=None):
        # x: [B, X*Y*... (spacial dims), num_heads, head_dim (N * liere_block_size)]
        x = x.view(x.shape[0], self.axes_length * self.spacial_dims, self.num_heads, self.head_dim)  # we need only the head for the matrix product

        if matrices is None:
            # build the rotation matrices (they are returned so they can be cached and reused for the k/q rotation)
            # the matrix product dimensions w/o batch are:
            #   if p = spacial_dims * axes_length and dim = head_dim // liere_block_size,
            #   then result = [p, dim] * exp{[dim, dim, p] * [p]}
            # -- from dimension reduction logic
            matrices = [
                flat_to_skew(v, self.liere_block_size, self.axes_length, self.spacial_dims)
                .view(self.liere_block_size, self.liere_block_size, self.axes_length * self.spacial_dims)
                @ self.spacial_indices.to(x.device, dtype=x.dtype).view(self.spacial_dims * self.axes_length)
                for v in self.vars
            ]

            # skew to rotation via the matrix exponential
            matrices = [torch.linalg.matrix_exp(A.float()) for A in matrices]
            # -- Fact: the matrix exponential of a block-diagonal matrix is also block diagonal,
            #    consisting of the matrix exponentials of the blocks
            # -- source: https://math.stackexchange.com/questions/3836462/matrix-exponential-of-a-block-diagonal-matrix
            # -- TODO: make it work with lower than fp32 precision (if possible in torch)

            # stack into a bigger block-diagonal matrix (back to head_dim x head_dim), then sparsify
            matrices = torch.block_diag(*matrices)
            # batch
            matrices = matrices.unsqueeze(0).repeat(x.shape[0], 1, 1)
            # to sparse
            matrices = matrices.to_sparse()

        # rotate the vectors through multiplication
        # -- make head_dim the leading (non-batch) dimension
        x = x.permute(0, 3, 1, 2)

        # NOTE: have to upcast x too because of `"bmm_sparse_cuda" not implemented for 'Half'`
        device_type = 'cpu' if str(x.device).startswith('cpu') else str(x.device).split(':')[0]
        with torch.autocast(device_type=device_type, enabled=False):
            dtype_store = x.dtype
            x = torch.bmm(matrices.float(), x.view(x.shape[0], self.head_dim, self.axes_length * self.spacial_dims * self.num_heads).float())
            x = x.view(x.shape[0], self.head_dim, self.axes_length * self.spacial_dims, self.num_heads).permute(0, 2, 3, 1).to(dtype_store)

        return x, matrices
UPD: fixed some formatting/reference mistakes and enabled matrix caching for using the same matrix for k/q rotation
UPD 2: made all matrix operations work
UPD 3: made it launchable with 1D LLaMA text generation
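A minimal usage sketch of the module above for the 1D case (the sizes are arbitrary, the input layout follows the comment in `forward`, and the sparse `bmm` path may require a CUDA device):

```python
rotator = AttentionLiereRotator(head_dim=64, liere_block_size=4, spacial_dims=1, axes_length=128, num_heads=8)

q = torch.randn(2, 128, 8, 64)  # [batch, seq, num_heads, head_dim]
k = torch.randn(2, 128, 8, 64)

q_rot, matrices = rotator(q)              # builds the rotation matrices on the first call
k_rot, _ = rotator(k, matrices=matrices)  # reuses the cached matrices for the keys
```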
Training on a very toy example of Shakespeare with the code above
(looks okayish, maybe it will look better when the model has more params)
@lucidrains were you able to look at the paper?
@tasansal no i haven't had the time, will take a look soon
Hello! Authors of the paper here. We're excited to see folks trying LieRE out.
Here's a minimal example of how we generated the skew symmetric matrices.
generator_raw_params = nn.Parameter(
    torch.rand(
        input_dimensionality,
        head_dim,
        head_dim,
    ) * 2 * math.pi
)

upper_triangle = torch.triu(generator_raw_params, diagonal=1)

skew_bases = upper_triangle - torch.transpose(upper_triangle, -1, -2)

in_basis_positions = positions.reshape(list(positions.shape) + [1] * 2) * skew_bases

generator_pos = torch.sum(in_basis_positions, dim=-3)

rotation = torch.matrix_exp(generator_pos.to(dtype=torch.float32)).to(dtype=positions.dtype)
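For anyone who wants to run this end to end, here is a self-contained sketch built around the snippet above; the sizes, the `positions` tensor, and the final application to a query vector are illustrative assumptions, not code from the paper:

```python
import math
import torch
import torch.nn as nn

# self-contained sketch around the minimal example above; the sizes, the positions
# tensor and the application to a query vector are assumptions for illustration
input_dimensionality = 2   # e.g. 2D patch coordinates
head_dim = 8
num_tokens = 5

generator_raw_params = nn.Parameter(
    torch.rand(input_dimensionality, head_dim, head_dim) * 2 * math.pi
)

# one coordinate vector per token: [num_tokens, input_dimensionality]
positions = torch.rand(num_tokens, input_dimensionality)

upper_triangle = torch.triu(generator_raw_params, diagonal=1)
skew_bases = upper_triangle - torch.transpose(upper_triangle, -1, -2)

# scale each skew-symmetric basis by the corresponding coordinate, then sum over dimensions
in_basis_positions = positions.reshape(list(positions.shape) + [1] * 2) * skew_bases
generator_pos = torch.sum(in_basis_positions, dim=-3)

# one rotation matrix per token: [num_tokens, head_dim, head_dim]
rotation = torch.matrix_exp(generator_pos.to(dtype=torch.float32)).to(dtype=positions.dtype)

# assumed application: rotate each token's (per-head) query or key vector
q = torch.randn(num_tokens, head_dim)
q_rotated = torch.einsum('nij,nj->ni', rotation, q)
```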
A longer code snippet is available at https://github.com/SophieOstmeier/LieRE_implementation.
It's very exciting to see someone using the method! Some notes that you may find interesting:
- We didn't use a sparse matrix representation and instead used the right tensor shapes and broadcasting to get the same effect (see the sketch after this list). The slowest configuration we tried was with block size 2 (GPUs don't like lots of small matrices). That said, using a full, dense matrix never slowed things down by a measurable amount, as the runtime was dominated by the quadratic component of the attention in our experiments.
- We used the backbone in https://github.com/kentaroy47/vision-transformers-cifar10 for our experiments. We have noticed that the baseline numbers are slightly different when using the default configuration of x-transformers vs. vision-transformers-cifar10 (cls token, no patch norm, more ff dropout).
- The performance comparisons were for training time to hit a fixed accuracy. LieRE hits the same accuracies faster than the other methods. Inference time for the same-sized model should be about the same.
- We saw that LieRE helps more for larger models on CIFAR 100 (model size sweeps are expensive, so we didn't get a chance to sweep model sizes on the larger datasets). We hope to update the arxiv version with those experiments soon.
- The noncommutativity means that LieRE is able to encode both absolute position information and relative position information. How are you breaking up the text into batch elements?
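To illustrate the first bullet, here is a rough sketch of the broadcast formulation; the shapes, names, and the toy rotation construction are assumptions, not the authors' implementation:

```python
import torch

# rough sketch of the "dense + broadcasting" idea from the first bullet above;
# all shapes and names are assumptions, not the authors' implementation
def apply_block_rotations(x, rotations):
    # x:         [batch, seq, heads, n_blocks * block] (queries or keys)
    # rotations: [seq, n_blocks, block, block] (one small rotation per position and per block)
    b, n, h, d = x.shape
    n_blocks, block = rotations.shape[1], rotations.shape[-1]
    x_blocks = x.view(b, n, h, n_blocks, block)
    # broadcast the per-position rotations over batch and heads; no sparse ops needed
    rotated = torch.einsum('nkij,bnhkj->bnhki', rotations, x_blocks)
    return rotated.reshape(b, n, h, d)

# toy usage: exponentiate random skew-symmetric generators into rotations
seq, n_blocks, block = 16, 4, 8
g = torch.randn(seq, n_blocks, block, block)
rotations = torch.linalg.matrix_exp(g - g.transpose(-1, -2))
x = torch.randn(2, seq, 3, n_blocks * block)
out = apply_block_rotations(x, rotations)   # same shape as x
```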
Hi, great job. I am the author of Lrpe. I would like to ask how the authors perceive the differences between LieRE and Lrpe. Let me briefly explain Lrpe here: Lrpe points out that the decomposable multiplicative relative position encoding $W_i$ is a unitary matrix, and derives that $W_t = P \Lambda^t P^{\mathbf{H}}$, where $P$ is a unitary matrix, $\Lambda$ is a diagonal matrix, and ${\mathbf{H}}$ denotes the conjugate transpose. Under this condition, we have:
$$ q_s^{\mathbf{H}} W_s^{\mathbf{H}} W_t k_t = q_s^{\mathbf{H}} P \Lambda^{-s} P^{\mathbf{H}} P \Lambda^{t} P^{\mathbf{H}} k_t = (\Lambda^{s} P^{\mathbf{H}} q_s)^{\mathbf{H}} (\Lambda^{t} P^{\mathbf{H}} k_t). $$
More details can be found in the paper. It would be greatly appreciated if we could discuss the performance and theoretical differences between LieRE and Lrpe.
@lucidrains Hi lucidrains, if you find Lrpe valuable, I can submit a pull request.
@kabachuha hey, just circling back to this and actually could find a use for this type of approach for another repository
will be running some experiments later today
do you have any updates on your end on how it compares with rotary?
i'll be honest, Lie groups / algebra stuff still goes over my head mostly haha
will continue watching videos online (thank you youtube) until it is beaten into me
going to close this, as even if i apply this technique somewhere, it won't be in this repository
@kabachuha @SophieOstmeier hey, so i had the chance to run an experiment this morning using LieRE on character level language modeling, and unfortunately seeing the same thing. LieRE has a strong start but then loses out to rotary
will probably stick with rotary a bit longer and wait for more follow up research on this approach
I'll run a longer experiment overnight, just doesn't seem like an immediate win yet
@kabachuha do let me know if you have any updated results on your end
just spent a few more hours with experiments, tried things like multi-headed generators, smaller rotation matrices, slowing the learning rate of the generator parameters.. but ultimately unable to beat rotary embeddings
i think i'll have to pass on this work for now
Thank you so much for taking the time and for trying the method! Did you try with images?
We didn't get a chance to get LieRE working with text yet and these experiments got me thinking.
I'm not sure that LieRE is as interesting for 1D inputs like text as it is for higher dimensional inputs. With one dimensional inputs the position transforms are all simultaneously diagonalizable since the eigenvectors for all the transformations are identical. From a capacity perspective, this might devolve into something similar to RoPE with learnable phases, since with a bit of algebra you could fold the other part of the transform into the key and query matrices. Let me try a pass at writing down my thinking. The exposition here is still a bit rough, so please bear with me, let me know what you think, and please point out if you think there is an error.
Recall that the position transform matrix is $\exp(A i)$ for a token in position $i$ (with $A$ a learnable skew-symmetric matrix). We can diagonalize the matrix with the eigendecomposition (note the matrix $\Lambda$ is pure imaginary with complex conjugate eigenvalues). For convenience (this will be important later) we order the eigenvalues of $\Lambda$ so that these pairs are adjacent.
$$
\begin{align*}
R_i &= \exp(A i) \\
&= \Sigma^{-1} \exp (i \Lambda) \Sigma \\
&= \Sigma^{T} \exp (i \Lambda) \Sigma \quad \textrm{(since $\Sigma$ is orthogonal)} \\
&= \Sigma^{-1} \begin{pmatrix} \exp(i \lambda_1) & 0 & \cdots & 0 \\ 0 & \exp(i \lambda_2) & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \exp(i \lambda_n) \end{pmatrix} \Sigma
\end{align*}
$$
When we compute the attention inner product between tokens $i,j$, with embeddings (glossing over how the heads are handled) $x_i, x_j$ we get the following:
$$
\begin{align}
\langle R_i K x_i, R_j Q x_j \rangle &= x_i^T K^T R_i^T R_j Q x_j \\
&= x_i^T K^T (\Sigma^T \exp (i \Lambda) \Sigma)^T \Sigma^T \exp (j \Lambda) \Sigma Q x_j \\
&= x_i^T K^T \Sigma^T \exp (i \Lambda) \Sigma \Sigma^T \exp (j \Lambda) \Sigma Q x_j \\
&= x_i^T K^T \Sigma^T \exp (i \Lambda) \exp (j \Lambda) \Sigma Q x_j
\end{align}
$$
Next we use a little trick to let us work with all real matrices shortly:
$$
\begin{align*}
R(\theta) &= \begin{pmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{pmatrix} \\
&= \frac{1}{\sqrt{2}} \begin{pmatrix} 1 & 1 \\ i & -i \end{pmatrix} \begin{pmatrix} e^{i\theta} & 0 \\ 0 & e^{-i\theta} \end{pmatrix} \frac{1}{\sqrt{2}} \begin{pmatrix} -i & 1 \\ i & 1 \end{pmatrix}
\end{align*}
$$
The $\Sigma$ can just be folded into the key and query matrices, since it does not depend on the token. Let's call $\Sigma^\prime$ the matrix built by repeating the 2x2 matrix above as blocks along the diagonal. We also add another multiplication by $\Sigma^\prime$, since it makes it easier to draw the parallel with RoPE.
$$
\begin{align*}
K' &:= \frac{1}{\sqrt 2} \begin{pmatrix}
1 & i & 0 & 0 & \cdots & 0 & 0 \\
1 & -i & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & 1 & i & \cdots & 0 & 0 \\
0 & 0 & 1 & -i & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & 1 & i \\
0 & 0 & 0 & 0 & \cdots & 1 & -i
\end{pmatrix} \Sigma K \\
Q' &:= \frac{1}{\sqrt 2} \begin{pmatrix}
1 & i & 0 & 0 & \cdots & 0 & 0 \\
1 & -i & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & 1 & i & \cdots & 0 & 0 \\
0 & 0 & 1 & -i & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & 1 & i \\
0 & 0 & 0 & 0 & \cdots & 1 & -i
\end{pmatrix} \Sigma Q
\end{align*}
$$
We can then continue the derivation above, using $\star$ to denote the conjugate transpose and noting that ${\Sigma^{\prime}}^{\star} {\Sigma^{\prime}} = I$.
$$
\begin{align}
\langle R_i K x_i, R_j Q x_j \rangle &= x_i^T K^T R_i^T R_j Q x_j \\
\ldots &= x_i^T K^T \Sigma^T \exp (i \Lambda) \exp (j \Lambda) \Sigma Q x_j \\
&= x_i^T K^T \Sigma^T {\Sigma^\prime}^{\star} {\Sigma^\prime} \exp (i \Lambda) \exp (j \Lambda) {\Sigma^\prime} {\Sigma^\prime}^\star \Sigma Q x_j \\
&= x_i^T {K'}^\star {\Sigma^\prime}^{\star} \exp (i \Lambda) \exp (j \Lambda) \Sigma^\prime Q' x_j
\end{align}
$$
Now, consider $R_i = {\Sigma^\prime}^\star \exp (i \Lambda) \Sigma^\prime$, noting that it is a block-diagonal matrix of 2D rotations, basically a RoPE matrix with learnable phases. Recall that $\star$ reduces to the ordinary transpose for real matrices. This gives us:
$$ x_i^T K'^\star {\Sigma^\prime}^{\star} \exp (i \Lambda) \exp (j \Lambda) \Sigma^\prime Q' x_j = x_i^T K'^\star R_i^T R_j Q' x_j $$
Basically, LieRE devolves into something similar to RoPE with learnable phases in the case of 1D inputs. If this passes some scrutiny, it would be worth including in the paper with more polished exposition and credit to you, @lucidrains, for helping figure this out.
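Relatedly, here is a quick numerical check that for a real skew-symmetric generator the product $R_i^T R_j$ depends only on $j - i$, which is the relative-position behaviour mentioned earlier (an illustrative sketch, not code from the paper):

```python
import torch

# for a real skew-symmetric generator A, R_i^T R_j = exp((j - i) A):
# the pairwise product depends only on the relative position j - i
torch.manual_seed(0)
g = torch.randn(8, 8, dtype=torch.float64)
A = g - g.T                                      # skew-symmetric generator
i, j = 3.0, 7.0
R_i = torch.linalg.matrix_exp(i * A)
R_j = torch.linalg.matrix_exp(j * A)
relative = torch.linalg.matrix_exp((j - i) * A)
print(torch.allclose(R_i.T @ R_j, relative, atol=1e-6))  # True
```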
@baxelrod hey, can't say I totally get all that but I will retry with a head-to-head against a rotary vision transformer
I really do want this to work!
@lucidrains, where is your implementation? I can test and compare with the 3D (spatiotemporal) ViTs I train. They are big models though, varying between ViT-H and ViT-G. Maybe I can experiment and compare against axial for a small test case with less data and a smaller model.
@tasansal yea that sounds good, i can get it to you in a short gist < 100 loc
do you have strong baselines with axial rotary positions already?
I have 3 models trained with the MAE technique for 3D scientific images and benchmarked on one downstream task (semantic segmentation). Not "strong" by any means, but it's something.
However, they are very large models that were expensive to train (time + GPU). I may want to train smaller variants both ways with axial and LieRE, which would still take some time.
I could also train the vanilla 2D ViT-tiny example with axial RoPE and LieRE on Hugging Face for you as well.
@tasansal yes that would be great! just realized this morning I'm out of cloud compute credits, so would be much appreciated and a time saver
let me get the code to you in a few hours
Great, thanks! Where can one get "free" cloud compute credits 😄
@tasansal out of curiosity, what dataset were you using? We did some testing with video data and 3D brain MRIs and are curious what other 3D datasets might make sense. Training those models on the data we had was expensive...
@tasansal hey, sorry for the delay, you can try it by copy pasting the module and function from this gist. you'll need to instantiate this in your transformer module, then pass in the rotations to all your attention modules on forward and apply them to the queries and keys
change num_dim = 3 for video
i'm not seeing it unfortunately, unless somehow there is an error in the gist supplied above
this is on a small imagewoof dataset with a rotary vit as the baseline
Thanks for sharing! We're also working on cleaning up our code, and will be open to sharing in ~1-2 weeks. Certain conferences won't let us post the code publicly until after publication, but we're happy to share via DM if you can wait a week or two for us to get the code to the point where it's not too embarrassing.
@SophieOstmeier no problem, do let me know if you see any blatant errors
@SophieOstmeier have you shared the code yet? 🙏
@bainjamain you can try it out by copy pasting this