rotary-embedding-torch

Slower than absolute positional embeddings?

Open · umarbutler opened this issue 1 year ago · 4 comments

Hi @lucidrains, thanks for creating this wonderful package, as well as x-transformers. I wanted to understand why rotary embeddings seem to be slower for me than absolute positional embeddings. I'm working with a BERT-like model, and I benchmarked absolute positional embeddings against rotary embeddings on a batch of 64 sequences of exactly 512 tokens each; the absolute positional embeddings came out faster. Using line_profiler, I can see that most of the time (>50%) is spent on the line (offset + seq_len) <= self.cached_freqs_seq_len.item() in RotaryEmbedding.forward() (https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py#L290).
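
I suspect that .item() call forces a device-to-host copy (and a synchronisation) on every forward pass. A minimal sketch of that kind of overhead in plain PyTorch, outside this library, just to illustrate what I mean:

import time

import torch

if torch.cuda.is_available():
    # Stand-in for a cached sequence length tensor that lives on the GPU.
    cached_freqs_seq_len = torch.tensor(512, device = 'cuda')

    torch.cuda.synchronize()
    start = time.perf_counter()

    for _ in range(1_000):
        # Each .item() copies the scalar back to the host and blocks until the
        # GPU stream has caught up, so it acts as a synchronisation point.
        _ = (0 + 512) <= cached_freqs_seq_len.item()

    torch.cuda.synchronize()
    print(f'{(time.perf_counter() - start) * 1e3:.2f} ms for 1,000 .item() calls')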

In terms of how I am using it, see the snippet below (with certain irrelevant code omitted):

# Copyright 2024 Umar Butler. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
class SelfAttention(nn.Module):
    def __init__(self, config):
        ...
        self.rotary_emb = RotaryEmbedding(
            dim = self.attention_head_size // 2,
            freqs_for = 'lang',
        )
        ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor | None]:
        rotator = self.rotary_emb.rotate_queries_or_keys
        
        # Project the hidden states into query, key and value vectors.
        query = self.transpose_for_scores(self.query(hidden_states))
        key = self.transpose_for_scores(self.key(hidden_states))
        value = self.transpose_for_scores(self.value(hidden_states))
        
        # Rotate the query and key vectors.
        query = rotator(query)
        key = rotator(key)
        
        # Use the appropriate accelerated attention implementation.
        attention, attention_probs = self.attend(
            query = query,
            key = key,
            value = value,
            attention_mask = attention_mask,
            head_mask = head_mask,
            dropout_prob = self.dropout_prob if self.training else 0.0,
            output_attentions = output_attentions,
        )
        
        attention = attention.view(attention.size()[:-2] + (self.all_head_size,))
        
        return (attention,) if not output_attentions else (attention, attention_probs)
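
For reference, the benchmark I mentioned above is roughly this shape (build_model here is just a placeholder for however the two model variants get constructed, not real code from my project):

import time

import torch

def benchmark(model: torch.nn.Module, input_ids: torch.Tensor, steps: int = 50) -> float:
    """Average forward-pass latency in milliseconds over `steps` iterations."""
    model.eval()

    with torch.no_grad():
        model(input_ids)  # Warm up kernels and caches before timing.

        if input_ids.is_cuda:
            torch.cuda.synchronize()

        start = time.perf_counter()

        for _ in range(steps):
            model(input_ids)

        if input_ids.is_cuda:
            torch.cuda.synchronize()

    return (time.perf_counter() - start) / steps * 1e3

# A batch of 64 sequences, each exactly 512 tokens long.
input_ids = torch.randint(0, 30_000, (64, 512), device = 'cuda')

# build_model is a placeholder for my actual model constructors.
absolute_model = build_model(position_embeddings = 'absolute').cuda()
rotary_model = build_model(position_embeddings = 'rotary').cuda()

print('absolute:', benchmark(absolute_model, input_ids), 'ms')
print('rotary:', benchmark(rotary_model, input_ids), 'ms')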

Would you happen to have any idea as to what could be causing this? Is this expected behaviour? Could it be that 512 tokens is not enough to realise the benefits of rotary embeddings? I do intend to train with longer sequences than that, but I want to be sure it will be more performant before I do so.

umarbutler · Sep 20 '24