cache costheta and sintheta in positional_encoding.py to reduce memory usage and computation
In the current implementation of RoPE, we don't cache costheta and sintheta; instead, theta is recomputed on the fly for every forward pass. It would be better to add a cache for this to avoid unnecessary computation and to reduce memory usage by not initializing the mx array in every forward pass. I am happy to work on this improvement if you guys think it is something we would like to do.
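For context, a rough sketch of what gets rebuilt on every call today (simplified, not the exact mlx.nn source; the function name is just for illustration):

import math
import mlx.core as mx

def rope_angles(N, D, offset=0, base=10000, dtype=mx.float32):
    # Rebuilt from scratch on every forward pass; nothing is reused between calls.
    half = D // 2
    positions = mx.arange(offset, N, dtype=dtype)
    freqs = mx.exp(-mx.arange(0.0, half, dtype=dtype) * (math.log(base) / half))
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
    return mx.cos(theta), mx.sin(theta)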
Yea 💯 . Was just discussing this the other day with @angeloskath (he may already have a diff not sure). I'm happy to have it as a contribution if we don't have something in the pipe already!
@angeloskath you had some thoughts about a good way to implement this. Maybe you can add that here (assuming you didn't do it already).
I played around with it a bit. Most likely I will follow the Transformers implementation if you guys don't have any objections. Here is the draft (just to demo the caching) that I tested locally, and it seems to be working. Please let me know your thoughts.
import math

import mlx.core as mx
import mlx.nn as nn


class RoPE(nn.RoPE):
    def __init__(
        self,
        dims: int,
        max_position_embeddings: int = 2048,
        rope_scaling_factor: float = 1.0,
        base: float = 10000,
        dtype=mx.float32,
    ):
        super().__init__(dims)
        self.base = base
        self.rope_scaling_factor = rope_scaling_factor
        # According to the paper, the head dimension should be even.
        assert dims % 2 == 0, "dims must be divisible by 2"
        D = dims // 2
        self.freqs = mx.exp(
            -mx.arange(0.0, D, dtype=dtype) * (math.log(self.base) / D)
        )
        self._set_cos_sin_cache(max_position_embeddings, dtype=dtype)

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        if N > self.max_seq_len_cached:
            # In case the default max_position_embeddings is too small,
            # grow the cache to keep backward compatibility.
            self._set_cos_sin_cache(seq_len=N, dtype=x.dtype)
        rx = self._compute_rope(self.costheta[offset:N], self.sintheta[offset:N], x)
        return mx.reshape(rx, shape)

    def _set_cos_sin_cache(self, seq_len: int, dtype=mx.float32):
        self.max_seq_len_cached = seq_len
        positions = mx.arange(0, self.max_seq_len_cached, dtype=dtype)
        positions = positions / self.rope_scaling_factor
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(self.freqs, (1, -1))
        self.costheta = mx.cos(theta)
        self.sintheta = mx.sin(theta)
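Assuming the draft above works as tested, a minimal sketch of exercising the cache could look like this (shapes and sizes are arbitrary, just for illustration):

# Minimal sketch of exercising the cache; shapes and sizes are arbitrary.
rope = RoPE(dims=64, max_position_embeddings=16)

x = mx.random.normal((1, 8, 8, 64))  # (batch, heads, seq_len, head_dim)
y = rope(x)                          # seq_len 8 <= 16: served from the cached tables
print(rope.max_seq_len_cached)       # 16

x_long = mx.random.normal((1, 8, 32, 64))
y_long = rope(x_long)                # seq_len 32 > 16: cache is rebuilt once, then reused
print(rope.max_seq_len_cached)       # 32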
Hey yeah. This would be a perfect use for "constant" protected members. I had it thusly in ALiBi so @mzbac feel free to add the following suggestion in the ALiBi implementation as well.
The way to do this efficiently would be to have a class variable that caches the costheta and sintheta as follows:
import math

import mlx.core as mx
from mlx.nn import Module


class RoPE(Module):
    # Class members implementing the cache, shared by all RoPE instances.
    _cos_sin_theta_key = None
    _cos_sin_theta_value = None

    @classmethod
    def create_cos_sin_theta(
        cls,
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        scale: float = 1.0,
        dtype=mx.float32,
    ):
        if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
            # Compute cos(θ) and sin(θ) the same way as now
            # (mirroring the freqs/theta computation from the draft above).
            half = D // 2
            positions = mx.arange(offset, N, dtype=dtype) * scale
            freqs = mx.exp(
                -mx.arange(0.0, half, dtype=dtype) * (math.log(base) / half)
            )
            theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
            cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
            cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
        return cls._cos_sin_theta_value
This means that it will be cached for all RoPE instances, and it would save a bit of memory and time during generation since the tables are only created once, in the first layer.
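As a quick illustration of the cache hit (the arguments are arbitrary): two layers asking for the same key get back the exact same arrays.

# Two layers requesting the same (N, D, ...) share the class-level cache.
cos1, sin1 = RoPE.create_cos_sin_theta(N=1024, D=64)
cos2, sin2 = RoPE.create_cos_sin_theta(N=1024, D=64)
assert cos1 is cos2 and sin1 is sin2  # computed once, reused by every layer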
@angeloskath, thanks for the suggestions. Having "constant" protected members makes sense to me, but I am not very familiar with ALiBi yet. I might start working on the RoPE first, and if it goes well, I will try to add it to ALiBi as well.
Closing the issue since the PR is merged.