cerebros-core-algorithm-alpha
Create an iRoPE embedding in PyTorch
Proposed solution, to be tested once the PyTorch Cerebros model is ready.
import math
import torch
import torch.nn as nn

# ------------- RotaryEmbedding -------------
class RotaryEmbedding(nn.Module):
    """
    Generates the (sin, cos) tensors used for the interleaved RoPE
    described in the original RoPE paper.
    """

    def __init__(self, dim: int, max_seq_len: int = 1024, temperature: float = 10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(
                f"Embedding dimension `dim` ({dim}) must be even for RotaryEmbedding."
            )
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.temperature = float(temperature)
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args
        ----
        x : FloatTensor (B, T, C) – the incoming activations

        Returns
        -------
        sin, cos : (B, T, C) – broadcast-ready sin / cos tensors
        """
        B, T, _ = x.shape
        device, dtype = x.device, x.dtype
        # Compute inverse frequencies [dim/2]
        inv_freq = 1.0 / (
            self.temperature ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
        # Positions [T]
        position = torch.arange(T, device=device, dtype=torch.float32)
        # Outer product → [T, dim/2]
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
        # Repeat to interleave: [a, b] → [a, a, b, b, ...]
        sin = torch.sin(sinusoid_inp).repeat_interleave(2, dim=-1)
        cos = torch.cos(sinusoid_inp).repeat_interleave(2, dim=-1)
        # Add batch dimension and broadcast
        sin = sin.unsqueeze(0).expand(B, T, -1).to(dtype)
        cos = cos.unsqueeze(0).expand(B, T, -1).to(dtype)
        return sin, cos
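
# Illustrative note (example values, not part of the module): with dim = 4 and a
# single position p, inv_freq = [1.0, temperature ** -0.5], so the
# repeat-interleaved layout is sin = [sin(p*f0), sin(p*f0), sin(p*f1), sin(p*f1)]
# (and the matching cos), pairing one rotation angle with each (even, odd)
# channel pair of x.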

# ------------- InterleavedRoPE -------------
class InterleavedRoPE(nn.Module):
    """
    Applies rotary positional embeddings to the input tensor.
    """

    def __init__(self, dim: int, max_seq_len: int = 1024):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(
                f"Embedding dimension `dim` ({dim}) must be even for InterleavedRoPE."
            )
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.rotary_emb = RotaryEmbedding(dim, max_seq_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args
        ----
        x : FloatTensor (B, T, C)

        Returns
        -------
        FloatTensor (B, T, C)
        """
        sin, cos = self.rotary_emb(x)
        return apply_rotary_pos_emb(x, sin, cos)
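
# Note: `apply_rotary_pos_emb` is defined further down; Python resolves the name
# at call time, so the forward pass works once the whole listing has been loaded.
# Usage sketch (example sizes): InterleavedRoPE(dim=64) maps a (B, T, 64) tensor
# to a (B, T, 64) tensor of the same shape.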

# ------------- Helpers (exact Keras semantics) -------------
def split_alternate(x: torch.Tensor) -> torch.Tensor:
    """
    De-interleaves the last dimension, grouping the even-indexed and
    odd-indexed elements: [a0, b0, a1, b1, ...] -> [a0, a1, ..., b0, b1, ...]
    """
    B, T, C = x.size()
    x = x.view(B, T, C // 2, 2).transpose(-2, -1).contiguous()
    return x.view(B, T, C)

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Builds the rotated companion of `x` for the interleaved layout:
    each channel pair (x0, x1) is mapped to (-x1, x0), i.e.
    [a0, b0, a1, b1, ...] -> [-b0, a0, -b1, a1, ...]
    """
    x = split_alternate(x)                     # [a0, a1, ..., b0, b1, ...]
    d = x.size(-1)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    rotated = torch.cat((-x2, x1), dim=-1)     # [-b0, ..., a0, ...]
    # Interleave again so the result lines up element-wise with the
    # repeat-interleaved sin/cos produced by RotaryEmbedding.
    B, T, C = rotated.size()
    return rotated.view(B, T, 2, C // 2).transpose(-2, -1).reshape(B, T, C)

def apply_rotary_pos_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """
    Apply the RoPE formula: x' = x * cos + rotate_half(x) * sin
    """
    dtype = x.dtype
    x = x.float()  # compute in fp32 for numerical stability
    rotated = (x * cos.float()) + (rotate_half(x) * sin.float())
    return rotated.to(dtype)  # cast back to the caller's dtype
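
A minimal sanity check for the listing above (the sizes B, T, and dim below are arbitrary example values, not anything the Cerebros model requires): because RoPE is a pure rotation of each (even, odd) channel pair, the output should keep the input's shape, position 0 should pass through unchanged, and the L2 norm of every channel pair should be preserved.

if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, dim = 2, 16, 64  # arbitrary example sizes
    rope = InterleavedRoPE(dim=dim, max_seq_len=128)
    x = torch.randn(B, T, dim)
    y = rope(x)
    assert y.shape == x.shape
    # Position 0 has rotation angle 0, so it is left unchanged.
    assert torch.allclose(y[:, 0], x[:, 0], atol=1e-5)
    # A rotation preserves the norm of every (even, odd) channel pair.
    x_pair_norms = x.view(B, T, dim // 2, 2).norm(dim=-1)
    y_pair_norms = y.view(B, T, dim // 2, 2).norm(dim=-1)
    assert torch.allclose(x_pair_norms, y_pair_norms, atol=1e-5)
    print("InterleavedRoPE sanity checks passed.")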