
Inquiry about attention mask used for EAGLE-3 Training

Open YanzuoLu opened this issue 6 months ago • 3 comments

May I ask why the attention mask also needs to be shifted during the training-time test process? From what I've read in the code, the attention mask is only applied to the KV cache stored at step = 0. If that's the case, wouldn't leaving the mask unshifted match what is described in Figure 6 of the paper?

https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L844

Thanks for your time.

YanzuoLu avatar Jun 27 '25 06:06 YanzuoLu

This masking is extremely weird: it gradually makes the past invisible to the input tokens. Here is a demo of the mask construction:

import torch

seq_len = 6
idx = 1  # current training-time-test step

# Start from a standard causal (lower-triangular) mask ...
attn_mask = torch.tril(torch.ones((seq_len, seq_len)))
index = torch.arange(seq_len)
# ... then overwrite the idx-th sub-diagonal, mirroring the cnets.py shift
attn_mask[index[idx:], index[:seq_len - idx]] = -100
print(attn_mask)

The output looks like this:

tensor([[   1.,    0.,    0.,    0.,    0.,    0.],
        [-100.,    1.,    0.,    0.,    0.,    0.],
        [   1., -100.,    1.,    0.,    0.,    0.],
        [   1.,    1., -100.,    1.,    0.,    0.],
        [   1.,    1.,    1., -100.,    1.,    0.],
        [   1.,    1.,    1.,    1., -100.,    1.]])

According to the code, these weights apply between the current queries and the first keys, i.e. k0. So the input tokens are [1, 2, 3, 4, 5, X] and the past keys are [0, 1, 2, 3, 4, 5]. This masking means token 1 attends to token 0, which is normal; but token 2 attends to token 1 while not attending to token 0, which is weird. I can't run the code right now, so maybe I'm missing something; it would be great to hear an explanation. 🤔 @Liyuhui-12 @hongyanz
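To make the pattern concrete, here is a small sketch (the helper `visible_keys` is illustrative, not from the repo) that lists, for each query row of the demo mask, which step-0 keys are still marked attendable:

```python
import torch

def visible_keys(attn_mask):
    """For each query row, list the key indices still marked
    attendable (entries equal to 1 in the demo mask)."""
    return [torch.nonzero(row == 1).flatten().tolist() for row in attn_mask]

seq_len, idx = 6, 1
attn_mask = torch.tril(torch.ones((seq_len, seq_len)))
index = torch.arange(seq_len)
attn_mask[index[idx:], index[:seq_len - idx]] = -100

# Query row 1 (input token 2) can see key 1 but no longer key 0.
print(visible_keys(attn_mask))
# → [[0], [1], [0, 2], [0, 1, 3], [0, 1, 2, 4], [0, 1, 2, 3, 5]]
```

Under this reading, every query past row 0 permanently loses exactly one past key per shift step, which matches the "gradually make the past unseen" description above.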

DeclK avatar Jun 30 '25 10:06 DeclK

That's true; I think this may be a tiny bug. If we comment out lines 841-844 in cnets.py, performance is slightly better in my experiments. But since the sequence length here is 2k and the shift is only 7 steps, most tokens in the sequence are minimally affected. The performance comparison chart is shown below (fix-test means lines 841-844 are commented out):

[performance comparison chart]
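A quick back-of-the-envelope count supports the "minimally affected" claim, assuming each of the 7 draft steps masks out one additional sub-diagonal of the 2k-token causal mask (an approximation of the shift, not the exact cnets.py logic):

```python
seq_len, max_shift = 2048, 7

# Total (query, key) pairs in a causal mask: lower triangle incl. diagonal.
total = seq_len * (seq_len + 1) // 2

# Entries knocked out if steps 1..max_shift each mask one more sub-diagonal.
masked = sum(seq_len - i for i in range(1, max_shift + 1))

print(masked, total, f"{masked / total:.2%}")  # 14308 2098176 0.68%
```

So under this assumption well under 1% of the attendable (query, key) pairs are wrongly masked, which is consistent with the performance gap from the fix being small.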

xinlong-yang avatar Jul 01 '25 09:07 xinlong-yang

Same question here. @hongyanz, is this a bug? Code link: https://github.com/SafeAILab/EAGLE/blob/a19206e0edfaa13d835027eb28933631ce89aa5c/eagle/traineagle3/cnets.py#L865-L868

SunnyLi2015 avatar Sep 09 '25 10:09 SunnyLi2015