rotary-embedding-torch
Slower than absolute positional embeddings?
Hi @lucidrains,
Thanks for creating this wonderful package, as well as x-transformers. I wanted to understand why rotary embeddings seem to be slower for me than absolute positional embeddings. I'm working with a BERT-like model and have benchmarked absolute positional embeddings against rotary embeddings on a batch of 64 sequences of exactly 512 tokens each, and found the absolute positional embeddings to be faster. Using a line profiler, I can see that most of the time (>50%) is spent on the line (offset + seq_len) <= self.cached_freqs_seq_len.item() in RotaryEmbedding.forward() (https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py#L290).
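For reference, this is the kind of micro-benchmark I have been running (a minimal sketch rather than my actual training code; the 12 heads of size 64 are just illustrative numbers matching my head size). I'm timing with CUDA events because .item() on a CUDA tensor forces a host-device sync, so a line profiler can end up attributing queued asynchronous GPU work to that single line:

# Minimal timing sketch (illustrative shapes: batch 64, 12 heads, 512 tokens, head size 64).
# CUDA events are used so that asynchronous kernel launches are not
# misattributed to whichever line happens to synchronise first.
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim = 32, freqs_for = 'lang').cuda()
q = torch.randn(64, 12, 512, 64, device = 'cuda')

start = torch.cuda.Event(enable_timing = True)
end = torch.cuda.Event(enable_timing = True)

torch.cuda.synchronize()
start.record()

for _ in range(100):
    rotary_emb.rotate_queries_or_keys(q)

end.record()
torch.cuda.synchronize()

print(f'rotary: {start.elapsed_time(end) / 100:.3f} ms per call')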
In terms of how I am using it, see the snippet below (with certain irrelevant code omitted):
# Copyright 2024 Umar Butler. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
from typing import Optional, Tuple

import torch
from torch import nn
from rotary_embedding_torch import RotaryEmbedding


class SelfAttention(nn.Module):
    def __init__(self, config):
        ...
        self.rotary_emb = RotaryEmbedding(
            dim = self.attention_head_size // 2,
            freqs_for = 'lang',
        )
        ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        rotator = self.rotary_emb.rotate_queries_or_keys

        # Project the hidden states into query, key and value vectors.
        query = self.transpose_for_scores(self.query(hidden_states))
        key = self.transpose_for_scores(self.key(hidden_states))
        value = self.transpose_for_scores(self.value(hidden_states))

        # Rotate the query and key vectors.
        query = rotator(query)
        key = rotator(key)

        # Use the appropriate accelerated attention implementation.
        attention, attention_probs = self.attend(
            query = query,
            key = key,
            value = value,
            attention_mask = attention_mask,
            head_mask = head_mask,
            dropout_prob = self.dropout_prob if self.training else 0.0,
            output_attentions = output_attentions,
        )

        attention = attention.view(attention.size()[:-2] + (self.all_head_size,))

        return (attention,) if not output_attentions else (attention, attention_probs)
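In case it is relevant, one workaround I have been considering (an untested sketch; I'm assuming the package's exported apply_rotary_emb helper, and that calling the RotaryEmbedding module with a tensor of positions returns the frequencies) is to compute the frequencies once per forward pass of the whole model and pass them down to every attention layer, so that the cached-frequency check and its .item() call are not hit in each layer:

# Untested sketch of a possible workaround: precompute the rotary
# frequencies once for the whole model rather than recomputing (and
# re-checking the cache) inside every attention layer.
import torch
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

rotary_emb = RotaryEmbedding(dim = 32, freqs_for = 'lang').cuda()

seq_len = 512
positions = torch.arange(seq_len, device = 'cuda')
freqs = rotary_emb(positions)  # computed once per forward pass

# Inside each attention layer, apply the precomputed frequencies directly
# instead of calling rotate_queries_or_keys.
query = torch.randn(64, 12, seq_len, 64, device = 'cuda')
key = torch.randn(64, 12, seq_len, 64, device = 'cuda')

query = apply_rotary_emb(freqs, query)
key = apply_rotary_emb(freqs, key)

I have not verified that this is numerically identical to rotate_queries_or_keys, so I may well be missing something there.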
Would you happen to have any idea as to what could be causing this? Is this expected behaviour? Could it be that 512 tokens is simply not enough to realise the benefits of rotary embeddings? I do intend to train with longer sequences than that, but I want to be sure rotary embeddings will be more performant before I do so.