
Potential RotaryEmbedding error?

Open · chenmoneygithub opened this issue 1 year ago · 3 comments

Hi team,

I was checking the implementation of the RotaryEmbedding layer and was a bit confused by the following computation:

def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
    # Split the feature axis into two contiguous halves:
    # [0, 1, 2, 3] -> [0, 1] and [2, 3].
    x1, x2 = ops.split(tensor, 2, axis=self.feature_axis)
    # Rotate half against half: (x1, x2) -> (-x2, x1).
    half_rot_tensor = ops.concatenate((-x2, x1), axis=self.feature_axis)
    return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

Here ops.split(tensor, 2, axis=self.feature_axis) splits the tensor down the middle, i.e., [0, 1, 2, 3] => [0, 1], [2, 3], but from the paper the split should be interleaved, i.e., [0, 2], [1, 3]:

[screenshot: the rotary position embedding equation from the RoFormer paper]

Can I get some clarification? Thanks!
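
To make the two splits concrete, here is a quick NumPy sketch (illustrative only, not library code):

import numpy as np

x = np.array([0., 1., 2., 3.])

# KerasNLP-style half split: halves [0, 1] and [2, 3], so index i is
# rotated against index i + d/2.
x1, x2 = np.split(x, 2)
print(np.concatenate([-x2, x1]))        # [-2. -3.  0.  1.]

# Paper-style interleaved split: halves [0, 2] and [1, 3], so index 2i
# is rotated against index 2i + 1.
x1, x2 = x[0::2], x[1::2]
print(np.stack([-x2, x1], axis=-1).reshape(-1))  # [-1.  0. -3.  2.]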

— chenmoneygithub, Jan 28 '24 23:01

@mattdangerw Since LLaMA already exists in KerasNLP, I assume the routine input/output checks have been done, which makes this implementation even more confusing to me.

Here is how another repo does the half rotation:

import torch
from einops import rearrange

def rotate_half(x):
    # Group the feature axis into interleaved pairs: (d r) with r = 2.
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    # Rotate each pair: (x1, x2) -> (-x2, x1).
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')
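
On a toy input this rotates adjacent (interleaved) elements against each other, matching the paper:

x = torch.tensor([0., 1., 2., 3.])
rotate_half(x)  # tensor([-1.,  0., -3.,  2.]) -- pairs (0, 1) and (2, 3)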

— chenmoneygithub, Jan 28 '24 23:01

@chenmoneygithub hi!!

Will take a look soon! cc @shivance in case he has thoughts (he wrote this originally)

— mattdangerw, Jan 29 '24 22:01

Hi Chen,

> Here ops.split(tensor, 2, axis=self.feature_axis) splits the tensor down the middle, i.e., [0, 1, 2, 3] => [0, 1], [2, 3], but from the paper the split should be interleaved, i.e., [0, 2], [1, 3]

This is the fundamental difference between the original PyTorch implementation [1] of the rotary embeddings and HuggingFace's implementation [2]. They still compute the same thing as long as the weights are permuted accordingly [3].
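
For intuition, here is a minimal NumPy sketch of that permutation (a hypothetical helper, not from either codebase), reordering the rows of a single head's projection weight so that interleaved pairs land in the half-split layout:

import numpy as np

def interleaved_to_half_split(w, head_dim):
    # Hypothetical helper. Rows ordered as interleaved pairs
    # [p0a, p0b, p1a, p1b, ...] become two halves
    # [p0a, p1a, ..., p0b, p1b, ...], so a half-split rotation pairs
    # the same features the paper's interleaved rotation does.
    # w: (head_dim, hidden_dim) projection slice for one head.
    return w.reshape(head_dim // 2, 2, -1).transpose(1, 0, 2).reshape(head_dim, -1)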

LLaMA's weights were ported over from HuggingFace directly, so the outputs of the KerasNLP model and the HuggingFace model matched [4]. Mistral, on the other hand, was ported from the original PyTorch implementation, and I had to add this thin layer around the rotary embedding layer in KerasNLP for the outputs to match.
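
As a rough illustration of what such an adapter has to do, here is a sketch under the assumption that the only difference is the feature pairing (illustrative, not the actual KerasNLP code; rotary_layer stands in for a half-split rotary embedding layer):

import numpy as np
from keras import ops

def interleaved_rotary(rotary_layer, x):
    # Sketch only: route even-indexed features into the first half and
    # odd-indexed ones into the second, apply the half-split rotary
    # layer, then undo the permutation. Assumes a static feature dim.
    d = x.shape[-1]
    to_half = np.concatenate([np.arange(0, d, 2), np.arange(1, d, 2)])
    y = rotary_layer(ops.take(x, to_half, axis=-1))
    return ops.take(y, np.argsort(to_half), axis=-1)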

— tirthasheshpatel, Feb 05 '24 21:02