rotary-embedding-torch

why dim of q be different from dim of RotaryEmbedding

Open · HiSultryMan opened this issue 2 years ago · 2 comments

In your demo code, the dim of q is 64 while the dim of RotaryEmbedding is 32. I checked the code: the features of q with index larger than 32 are not rotated at all. Confused.

  import torch
  from rotary_embedding_torch import RotaryEmbedding
  
  rotary_emb = RotaryEmbedding(dim = 32)
  
  # mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc)
  
  q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head)
  k = torch.randn(1, 8, 1024, 64) # keys
  
  # apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)
  
  q = rotary_emb.rotate_queries_or_keys(q)
  k = rotary_emb.rotate_queries_or_keys(k)
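
A quick way to see this (a sketch, only assuming the rotary_emb and tensor shapes above) is to compare the tensor before and after rotation:

  # sketch: check which features rotate_queries_or_keys actually touches
  q = torch.randn(1, 8, 1024, 64)
  q_rotated = rotary_emb.rotate_queries_or_keys(q)
  
  print(torch.allclose(q_rotated[..., 32:], q[..., 32:]))  # True  - features past the rotary dim come back unchanged
  print(torch.allclose(q_rotated[..., :32], q[..., :32]))  # False - the first 32 features are rotated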

HiSultryMan commented Apr 12, 2023

It's only applied to half of the feature dimension. Here is another implementation, from a Kaggle post:

import torch

class Rotary(torch.nn.Module):

    def __init__(self, dim, base=10000):
        super().__init__()
        # inverse frequency for each feature pair: base^(-2i / dim)
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        # rebuild the cos/sin tables only when the sequence length changes
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)    # (seq_len, dim / 2)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device) # (seq_len, dim)
            # shaped (seq_len, 1, 1, dim) so they broadcast over batch and heads
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached

# rotary pos emb helpers:
def rotate_half(x):
    # swap the two halves of the feature dimension, negating the second half
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
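
A usage sketch for these helpers (not from the Kaggle post; the (seq_len, batch, heads, head_dim) layout is chosen here so the cached cos/sin tables broadcast correctly):

# usage sketch: layout (seq_len, batch, heads, head_dim), matching the
# cached (seq_len, 1, 1, dim) tables
rotary = Rotary(dim=64)

q = torch.randn(1024, 1, 8, 64)
k = torch.randn(1024, 1, 8, 64)

cos, sin = rotary(q, seq_dim=0)              # build (or reuse) the cached tables
q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotates the full 64-dim head

Note that, unlike RotaryEmbedding(dim = 32) on a 64-dim head above, this rotates the whole head dimension.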

thorinf commented Apr 13, 2023

I also had the same question. Based on the discussion in lucidrains/x-transformers#40, it seems that doing a partial rotation can slightly improve performance.

I think a brief note about this in the README would help clear up future confusion!
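
For concreteness, here is what the partial rotation looks like in terms of the helpers above (apply_partial_rotary is a hypothetical name; it assumes a Rotary(dim=rot_dim) instance and the rotate_half helper from the previous comment):

# hypothetical helper: rotate only the first rot_dim features of each head and
# pass the rest through untouched, which is what RotaryEmbedding(dim = 32)
# effectively does on a 64-dim head
def apply_partial_rotary(q, k, rotary, rot_dim, seq_dim=0):
    cos, sin = rotary(q[..., :rot_dim], seq_dim=seq_dim)  # rotary = Rotary(dim=rot_dim)
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)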

alstonlo commented Jul 08, 2023