
cache costheta and sintheta in positional_encoding.py to reduce memory usage and computation

mzbac opened this issue on Jan 04 '24 · 4 comments

In the current implementation of RoPE, we are not caching costheta and sintheta. Instead, we compute theta on the fly for every forward pass. It would be better to add a cache for these to avoid unnecessary computation and to reduce memory usage by not initializing the mx arrays in the forward pass. I am happy to work on this improvement if you think it is something we would like to do.
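
Roughly, the pattern being described is something like the sketch below (names and signature are illustrative, not the exact library code): the angles are rebuilt, and new mx arrays allocated, on every forward pass.

import math

import mlx.core as mx


def rope_angles(N: int, D: int, base: float = 10000.0, dtype=mx.float32):
    # Illustrative sketch of the uncached path: these arrays are recreated
    # on every forward pass instead of being computed once and reused.
    half = D // 2
    positions = mx.arange(0, N, dtype=dtype)
    freqs = mx.exp(-mx.arange(0.0, half, dtype=dtype) * (math.log(base) / half))
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
    return mx.cos(theta), mx.sin(theta)  # each of shape (N, D // 2)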

mzbac avatar Jan 04 '24 06:01 mzbac

Yea 💯. Was just discussing this the other day with @angeloskath (he may already have a diff, not sure). I'm happy to have it as a contribution if we don't have something in the pipe already!

@angeloskath you had some thoughts about a good way to implement this. Maybe you can add that here (assuming you didn't do it already).

awni avatar Jan 04 '24 14:01 awni

I played around with it a bit. Most likely I will follow the Transformers implementation, if you guys don't have any objections. Here is a draft (just to demo the caching) that I tested locally, and it seems to be working. Please let me know your thoughts.

import math

import mlx.core as mx
import mlx.nn as nn


class RoPE(nn.RoPE):
    def __init__(
        self,
        dims: int,
        max_position_embeddings: int = 2048,
        rope_scaling_factor: float = 1.0,
        base: float = 10000,
        dtype=mx.float32,
    ):
        super().__init__(dims)
        self.base = base
        self.rope_scaling_factor = rope_scaling_factor

        # according to the paper, the head dimension should be even
        assert dims % 2 == 0, "dims must be divisible by 2"

        # Precompute the inverse frequencies once and build the cos/sin cache
        # up to max_position_embeddings.
        D = dims // 2
        self.freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(self.base) / D))
        self._set_cos_sin_cache(max_position_embeddings, dtype=dtype)

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        if N > self.max_seq_len_cached:
            # In case the default max_position_embeddings is too small,
            # grow the cache so longer sequences keep working.
            self._set_cos_sin_cache(seq_len=N, dtype=x.dtype)

        rx = self._compute_rope(self.costheta[offset:N], self.sintheta[offset:N], x)

        return mx.reshape(rx, shape)

    def _set_cos_sin_cache(self, seq_len: int, dtype=mx.float32):
        # Cache cos(theta) and sin(theta) for all positions up to seq_len.
        self.max_seq_len_cached = seq_len
        positions = mx.arange(0, self.max_seq_len_cached, dtype=dtype)
        positions = positions / self.rope_scaling_factor
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(self.freqs, (1, -1))
        self.costheta = mx.cos(theta)
        self.sintheta = mx.sin(theta)
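
A quick usage sketch of the draft, for context (the shapes, dims=64, and the offset call are illustrative, not from the repo):

# Hypothetical usage of the RoPE subclass above: queries/keys of shape
# (batch, num_heads, seq_len, head_dim), with head_dim equal to `dims`.
rope = RoPE(dims=64, max_position_embeddings=2048)

x = mx.random.normal((1, 8, 16, 64))
y = rope(x)  # positions 0..15 come straight from the precomputed cache

x_next = mx.random.normal((1, 8, 1, 64))
y_next = rope(x_next, offset=16)  # incremental decoding reuses the same cache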

mzbac avatar Jan 04 '24 15:01 mzbac

Hey, yeah. This would be a perfect use for "constant" protected members. I had it set up this way for ALiBi, so @mzbac feel free to apply the following suggestion to the ALiBi implementation as well.

The way to do this efficiently would be to have a class variable that caches the costheta and sintheta as follows:

class RoPE(Module):

    # Class members implementing the cache
    _cos_sin_theta_key = None
    _cos_sin_theta_value = None

    @classmethod
    def create_cos_sin_theta(
        cls,
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        scale: float = 1.0,
        dtype=mx.float32,
    ):
        if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
            # compute cos(θ) and sin(θ) same as it is now
            cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
            cls._cos_sin_theta_value = (costheta, sintheta)
        return cls._cos_sin_theta_value

This means it will be cached across all RoPE instances, which saves a bit of memory and time during generation since the arrays are only created once, in the first layer.
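
Filling in the elided line, a sketch of what the cached classmethod could look like end to end, reusing the angle computation from the draft above (the body here is my reading of "same as it is now", not the exact code that was merged):

import math

import mlx.core as mx
import mlx.nn as nn


class RoPE(nn.Module):
    # Class members implementing the cache, shared by every RoPE instance.
    _cos_sin_theta_key = None
    _cos_sin_theta_value = None

    @classmethod
    def create_cos_sin_theta(
        cls,
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        scale: float = 1.0,
        dtype=mx.float32,
    ):
        if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
            # Cache miss: recompute the angles with the same math as the
            # earlier draft, for positions offset..N-1 and half the head dim.
            half = D // 2
            positions = mx.arange(offset, N, dtype=dtype) * scale
            freqs = mx.exp(-mx.arange(0.0, half, dtype=dtype) * (math.log(base) / half))
            theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
            cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
            cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
        return cls._cos_sin_theta_value


# A second call with identical arguments is a cache hit and returns the
# same arrays instead of allocating new ones.
cos1, sin1 = RoPE.create_cos_sin_theta(N=16, D=64)
cos2, sin2 = RoPE.create_cos_sin_theta(N=16, D=64)
assert cos1 is cos2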

angeloskath avatar Jan 04 '24 23:01 angeloskath

@angeloskath, thanks for the suggestions. Having "constant" protected members makes sense to me, but I am not very familiar with ALiBi yet. I might start with RoPE first, and if it goes well, I will try to add it to ALiBi as well.

mzbac avatar Jan 05 '24 00:01 mzbac

Closing the issue since the PR is merged.

angeloskath avatar Jan 07 '24 11:01 angeloskath